TensorFlow Decision Forests with TFX (model serving and evaluation)

Hi all, I’m trying to incorporate a gradient-boosted trees (GBT) model built with TFDF into an existing TFX project. So far we are able to train the GBT model successfully inside a Trainer component, but we are having trouble creating an Evaluator stage for it.

I believe our issue lies in the serving functions we’re creating for this model. We are following the same technique we previously used for standard TF/Keras neural networks. Here is the error that occurs during the evaluation stage:

RuntimeError: tensorflow.python.framework.errors_impl.InvalidArgumentError: slice index 0 of dimension 1 out of bounds.
            [[{{node StatefulPartitionedCall/gradient_boosted_trees_model/StatefulPartitionedCall/strided_slice_3}}]] [Op:__inference_signature_wrapper_1194552]

Here is the code for creating the serving functions & signatures:

def _get_serve_tf_examples_fn(model, tf_transform_output, transform=False):
    """Returns a function that parses a serialized tf.Example and applies TFT."""

    # Attach the TFT layer to the model so it is tracked and exported with it.
    model.tft_layer = tf_transform_output.transform_features_layer()

    @tf.function
    def serve_tf_examples_fn(serialized_tf_examples):
        """Returns the output to be used in the serving signature."""
        raw_feature_spec = tf_transform_output.raw_feature_spec().copy()
        feature_spec = {k: v for k, v in raw_feature_spec.items() if k in FEATURES}
        # Wrap the label spec in a dict so it can be merged with feature_spec.
        label_spec = (
            {_LABEL_KEY: raw_feature_spec[_LABEL_KEY]}
            if _LABEL_KEY in raw_feature_spec else {}
        )

        if transform:
            # The 'transform' signature parses and transforms features and label.
            parsed_features_with_label = tf.io.parse_example(
                serialized_tf_examples, {**feature_spec, **label_spec}
            )
            return model.tft_layer(parsed_features_with_label)

        # The default serving path: parse, transform, then run the model.
        parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
        transformed_features = model.tft_layer(parsed_features)
        return model(transformed_features)

    return serve_tf_examples_fn

...

# Code for creating signatures & saving model
_serving_default = _get_serve_tf_examples_fn(model, tf_transform_output) \
    .get_concrete_function(
        tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')
    )

_transform = _get_serve_tf_examples_fn(model, tf_transform_output, transform=True) \
    .get_concrete_function(
        tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')
    )

signatures = {'serving_default': _serving_default, 'transform': _transform}

model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)

Has anyone successfully used TFDF with TFX?


Couldn’t find a way to edit my original topic, but I have also inspected the saved model.

I ran:
$ saved_model_cli show --dir <path_to_model> --all

This error was reported at the end of the output:

FileNotFoundError: Op type not registered 'SimpleMLInferenceOpWithHandle' in binary running on craftsman. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed.
 If trying to load on a different device from the computational device, consider using setting the `experimental_io_device` option on tf.saved_model.LoadOptions to the io_device such as '/job:localhost'.

Here is the rest of the output, which came before the error:

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is: 

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['examples'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: serving_default_examples:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['output_0'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: StatefulPartitionedCall_10:0
  Method name is: tensorflow/serving/predict

signature_def['transform']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['examples'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: transform_examples:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['<redacted>'] tensor_info:
        dtype: DT_INT64
        shape: (-1, -1)
        name: 
    outputs['<redacted>'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1)
        name: StatefulPartitionedCall_11:3
    [... dozens of further '<redacted>' outputs with dtype DT_INT64, DT_FLOAT, or DT_STRING and shape (-1), (-1, 1), or (-1, -1), all following the same pattern (names StatefulPartitionedCall_11:7 through StatefulPartitionedCall_11:161); omitted here for brevity ...]
  Method name is: tensorflow/serving/predict

Hi Ryan,

Thanks for the report.

Yes, we have demos of TF-DF in TFX, but we have yet to publish them :).

Regarding your first error, can you replace the following:

model(transformed_features)

by the following:

transformed_features_rank2 = tf.nest.map_structure(
    lambda v: tf.expand_dims(v, axis=1), transformed_features)
model(transformed_features_rank2)

If this works, it indicates that the tensor features are of rank 1, but that the model was exported with an input signature of rank 2. Keras applies a tf.expand_dims(..., axis=1) during fit, predict, and evaluate calls, but it does not export this in the model’s call (i.e. model(features)). This can cause this type of issue. Neural-net users rarely use rank-1 features, so this issue is rarely encountered. We are working on a better solution there.

An alternative solution would be to inject the expand_dims into the feature transformation.
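
For illustration, a minimal sketch of that alternative applied inside the serving function from the first post (the helper name _expand_to_rank2 is hypothetical, and this assumes all features are dense rank-1 tensors):

def _expand_to_rank2(features):
    # Add a trailing axis to every dense rank-1 tensor so the features
    # match the rank-2 input signature the model was exported with.
    return {
        name: tf.expand_dims(tensor, axis=1)
        if isinstance(tensor, tf.Tensor) and tensor.shape.rank == 1
        else tensor
        for name, tensor in features.items()
    }

transformed_features = _expand_to_rank2(model.tft_layer(parsed_features))
return model(transformed_features)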

If this does not solve the issue, can you:

  • Print the model signature with: model.call.pretty_printed_concrete_signatures().
  • Print the transformed_features shape i.e. print(transformed_features).

Regarding the FileNotFoundError: Op type not registered error: this is due to saved_model_cli being compiled without the TF-DF ops. The TF-DF ops are not yet canonical, and therefore they are not yet included in pre-compiled binaries. Until this is done, models need to be inspected “manually” (unless you have the courage to re-compile those binaries :slight_smile: ).

A similar error can occur when loading a TF-DF SavedModel in other binaries (e.g. the TFX Evaluator [depending on the version of TFX] or the C++ TF Serving infrastructure). In this case, you will have to do the op injection yourself.
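
For a Python binary, a minimal sketch of the op injection (importing TF-DF registers its custom ops in the running process before the model is loaded; the model path is a placeholder):

import tensorflow as tf
# Importing tensorflow_decision_forests registers the SimpleML* custom ops
# in this process, so the SavedModel's inference graph can be loaded.
import tensorflow_decision_forests as tfdf

model = tf.saved_model.load('<path_to_model>')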

Cheers,
M.


Hi Mathieu,

Thank you for taking the time to help me.

Yes, we have demos of TF-DF in TFX, but we have yet to publish them :).

I’m looking forward to these demos!

I have been working on this on and off for the past few days and am still running into some issues. I can run the entire pipeline successfully when all features are Tensor objects, but I have issues when they are SparseTensor objects, despite transforming them. I am still investigating, but wanted to check whether it is known that sparse tensors need special handling.

Thank you,
Ryan

Hello Ryan,

Sparse tensors are supported for all feature semantics, but they have some specific logic (see the sketch after this list):

  • Except for categorical-set features, each indexed value of a sparse tensor is treated as a separate individual feature, e.g. value[:, 0], value[:, 1], etc.
  • Unspecified sparse values are treated as missing feature values.
  • The maximum shape of a sparse tensor should be set if it contains multiple features. If the maximum shape is not set (i.e. value.shape[1] is None), it is assumed to be one, i.e. there is only one feature. You can see the list of individual features at the start of the training logs.
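
To illustrate the last point, a minimal sketch using Keras Input placeholders to show the static shapes involved (the dtype and the width of 3 are assumptions for the example):

import tensorflow as tf

# Static shape (None, None): value.shape[1] is None, so TF-DF assumes the
# sparse tensor holds a single feature.
unknown_width = tf.keras.Input(shape=(None,), sparse=True, dtype=tf.string)

# Static shape (None, 3): TF-DF treats this as three individual features,
# value[:, 0], value[:, 1], and value[:, 2].
fixed_width = tf.keras.Input(shape=(3,), sparse=True, dtype=tf.string)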

If you have specific issues with them, don’t hesitate to share them (even if you solved the issue). This kind of feedback is great for improving the API, errors/warnings, and documentation :).


Hi,

I’ve put some effort into serving TF-DF models with TensorFlow Serving.
https://blog.ml6.eu/serving-decision-forests-with-tensorflow-b447ea4fc81c

There is an open PR for the code; in the meantime, the code/Docker image is available at Docker Hub.

All feedback is welcome on the image, and also on the PR. Happy to help get the necessary parts into TensorFlow Serving.

Pieter


This is really cool :slight_smile: .


+1 to that, very cool!


Hi @Mathieu, earlier you wrote:

“A similar error can occur when loading a TF-DF SavedModel in other binaries (e.g. the TFX Evaluator …). In this case, you will have to do the op injection yourself.”

I’m trying to get a TFX pipeline to run in Vertex AI as a custom training job (TFX 1.5.0, TFDF 0.2.1, TF 2.7.0), but it’s failing in the Evaluator with this error:

tensorflow.python.framework.errors_impl.NotFoundError: Op type not registered 'SimpleMLLoadModelFromPathWithHandle' in binary running on ...

It runs locally just fine but not in Vertex AI. Could you elaborate on how I could import TF-DF in the Evaluator component?

Thank you,
Ed

Hi Ed,

The error indicates that the TensorFlow custom op (i.e. the compiled C++ code that runs the model) is not loaded in memory.

The (open-source) TFX Evaluator is a simple Python script. Therefore, I suspect that importing TF-DF in the evaluator could solve the issue.

Can you try adding the line import tensorflow_decision_forests at the top of your evaluator Python script? For example, right after import tensorflow_model_analysis as tfma.
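
Concretely, the top of the evaluator module would look like this (a minimal sketch):

import tensorflow_model_analysis as tfma
# Importing TF-DF registers its custom ops (e.g. SimpleMLLoadModelFromPathWithHandle)
# in this process, even though the import is otherwise unused.
import tensorflow_decision_forests as tfdf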

Cheers,
M.

Hi Mathieu, thanks for getting back to me. I’m using the taxi and penguin TFX templates as the basis for the pipeline that includes the TFDF library. As such, I have a pipeline.py file where the Evaluator is used among other TFX components.
I’ve added the import line you suggested, but I don’t think it has any bearing on the pipeline execution happening on the Vertex platform (I run runner_kubeflow_v2.py locally, which submits a job to the Vertex platform). It seems like I need to get whatever executor is running on Vertex to execute the above import statement, but I just don’t know where …

Okay. So you need to figure out a way to make the import happen in the evaluation worker :).

I have little experience with Vertex AI, and from a quick look there does not seem to be much documentation about “custom ops” and “custom evaluator components” (which is the official name for this kind of thing). Maybe the section about custom pip packages would make sense to explore further.
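
As a sketch of that direction (the flag and file name are assumptions; this relies on the standard Beam worker-dependency mechanism and on TF-DF being installable on the workers):

from tfx import v1 as tfx

# Hypothetical sketch: ask the Beam workers running the Evaluator to
# install TF-DF, so the evaluator process can register the custom ops.
# requirements.txt would contain the line: tensorflow_decision_forests==0.2.1
beam_pipeline_args = [
    '--requirements_file=requirements.txt',
]

# pipeline_name, pipeline_root, and components come from the surrounding
# pipeline definition.
pipeline = tfx.dsl.Pipeline(
    pipeline_name=pipeline_name,
    pipeline_root=pipeline_root,
    components=components,
    beam_pipeline_args=beam_pipeline_args,
)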

And in the meantime, I’ve asked a Vertex AI + TFDF user for help. ETA > 1 week.

Ok, thanks very much!

Hi Ed,

As a sanity check, could you provide the TFMA version in your image? The TFX Evaluator uses TFMA as the underlying module, and TF-DF support was added in TFMA 0.34.0 [github]. For reference, TFX 1.5.0 uses TFMA 0.36.0.
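
For a quick check inside the image, following the same pattern used later in this thread:

$ python -m pip list | grep tensorflow-model-analysis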

Please also check the worker logs of the Dataflow job created by the TFX Evaluator (Dataflow → Jobs → click on the job → click ‘Show’ at the bottom next to Logs → WORKER LOGS → enter ‘tensorflow_decision_forests’ in the filter), and find the line imported tensorflow_decision_forests or tensorflow_decision_forests is not available: No module named 'tensorflow_decision_forests'. Please let us know whether TF-DF is properly imported.

If everything checks out but the problem persists, we can try a more general workaround for custom ops.

Best,
Alister

Hi @aliao, thanks for the suggestion, but the TFX Evaluator is being run with Beam (as I initiated this job on Vertex), so I can’t check the logs as you’ve indicated. There doesn’t appear to be any mention of tensorflow_decision_forests in the logs that I can see (viewable in this gist: Evaluator Logs + pipeline.json · GitHub).
I can see in the JSON file generated by the tfx.dsl.Pipeline API call that it specifies the TFX version as 1.5.0 and the image to use for the Evaluator as gcr.io/tfx-oss-public/tfx:1.5.0. I included this file in the above gist.

Hi @Ed_Park and @Mathieu,

We are having the exact same problem running a TFDF model in Vertex AI with an Evaluator component. We ran into problems with the Trainer component as well, because TFDF is not included in the default TFX container used by Vertex AI: gcr.io/tfx-oss-public/tfx:1.5.0. We were able to solve that by creating a custom container that includes the TFDF library, but in order to get Vertex AI to use the custom container, we had to use the Google Cloud AI Platform Trainer component instead: tfx.v1.extensions.google_cloud_ai_platform.Trainer  |  TFX  |  TensorFlow

Currently, the only Google Cloud AI Platform TFX components available are the BulkInferrer, Pusher, Trainer, and Tuner. So it looks like we’ll have to reverse-engineer a custom Evaluator component that can use a custom container.

Let me know if you guys have any other suggestions, because we don’t have much experience with custom TFX components, and it will probably take a significant investment.

Ah, it looks like we may just need to upgrade to TFX 1.6.0, because the release notes say experimental support for TFDF was added.

Hi @anon43767231, I still don’t have a solution for getting the Evaluator component to work, so I’ve unfortunately had to leave it out of my pipeline. I haven’t tried 1.6.0 yet, so perhaps it may just start to work. Let us know if 1.6.0 works for you, and I’ll do likewise when I try it out.

I gave it a shot yesterday, and it looks like TFDF is still not included in the GCR image. Here’s a list of the TensorFlow libraries installed in the TFX 1.6.1 GCR image:

Step 1/2 : FROM gcr.io/tfx-oss-public/tfx:1.6.1
 ---> 52c7ed90ce2d
Step 2/2 : RUN python -m pip list | grep -E "tfx|tensorflow"
 ---> Running in 21981c437fe0
tensorflow                            2.7.0
tensorflow-cloud                      0.1.16
tensorflow-data-validation            1.6.0
tensorflow-datasets                   4.4.0
tensorflow-estimator                  2.7.0
tensorflow-hub                        0.12.0
tensorflow-io                         0.24.0
tensorflow-io-gcs-filesystem          0.24.0
tensorflow-metadata                   1.6.0
tensorflow-model-analysis             0.37.0
tensorflow-probability                0.14.1
tensorflow-serving-api                2.7.0
tensorflow-transform                  1.6.0
tfx                                   1.6.1
tfx-bsl                               1.6.0

It looks like we will have to create a custom Docker image that includes TFDF. This is easy enough. The problem is getting Vertex AI to use the custom image instead of the GCR Docker image. We were able to do this for the Trainer using the Google Cloud AI Platform Trainer component, which accepts a worker_pool_specs field in the custom_config dictionary where you can set the URI of your Docker image, as sketched below.
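
For reference, a sketch of that configuration (the project, region, and image URI are placeholders; the key names follow the TFX 1.x Vertex Trainer API as we understand it, and module_file/example_gen come from the surrounding pipeline):

from tfx import v1 as tfx

trainer = tfx.extensions.google_cloud_ai_platform.Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    train_args=tfx.proto.TrainArgs(num_steps=100),
    eval_args=tfx.proto.EvalArgs(num_steps=10),
    custom_config={
        tfx.extensions.google_cloud_ai_platform.ENABLE_VERTEX_KEY: True,
        tfx.extensions.google_cloud_ai_platform.VERTEX_REGION_KEY: 'us-central1',
        tfx.extensions.google_cloud_ai_platform.TRAINING_ARGS_KEY: {
            'project': 'my-gcp-project',
            'worker_pool_specs': [{
                'machine_spec': {'machine_type': 'n1-standard-4'},
                'replica_count': 1,
                # Custom image built on gcr.io/tfx-oss-public/tfx with TFDF added.
                'container_spec': {'image_uri': 'gcr.io/my-gcp-project/tfx-tfdf:1.5.0'},
            }],
        },
    },
)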

The problem with the Evaluator is that there is no Google Cloud AI Platform version of it, so we would have to create a custom TFX component that runs the Evaluator in a separate custom training job in Vertex AI using the new Docker image.