Help distributing predictions with Spark

Hi,

I would like to distribute my TensorFlow prediction using a Spark RDD that contains one tfrecord file per partition. Here is a snippet of my code:

# Assign each scoring file its own partition index and broadcast the index map
df_score_files = df_score_files.withColumn("part_idx", f.monotonically_increasing_id()).repartition("part_idx")
n_part_idx, part_idx_map = create_id_map(df_score_files, id_name="part_idx")
part_idx_map = spark.sparkContext.broadcast(part_idx_map)

# Route every row to its own partition and run the scoring function once per partition
predictions_rdd = df_score_files.rdd.map(lambda row: (part_idx_map.value[row.part_idx], row)) \
    .partitionBy(n_part_idx, lambda x: x) \
    .mapPartitions(distribute_score)
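
For reference, create_id_map just maps each distinct part_idx value to a contiguous 0..n-1 index so that partitionBy can place one file per partition. Roughly something like this (simplified sketch, not the exact implementation):

# Simplified sketch of create_id_map (assumed behaviour, not the exact implementation):
# it returns the number of distinct ids and a dict mapping each id to 0..n-1,
# so that partitionBy(n_part_idx, lambda x: x) sends each file to its own partition.
def create_id_map(df, id_name):
    ids = sorted(row[id_name] for row in df.select(id_name).distinct().collect())
    id_map = {id_value: idx for idx, id_value in enumerate(ids)}
    return len(id_map), id_map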

The function distribute_score is applied to every partition:

def distribute_score(iterator):
    # Collect the tfrecord file paths that landed in this partition
    score_files = [row[1].asDict()["score_files"] for row in iterator]
    # NOTE: files_list is pinned to a single hardcoded file for now; score_files is not used yet
    dfpred, _ = predict_evaluate1(
        files_list="gs://b_meta_algo_tmp/nrpc_deep_wide_tfrecords_v1/selector=test_standard_scaler/account=trivago/split=pred/yyyy_mm_dd=2021-05-30/part-00019-e6d924ff-6e7b-40ab-ae9f-9d2c833f92c2-c000.tfrecord.gz",
        estimator=estimator,
        checkpoint_path=None,
        features_dict=features_dict,
        input_tfrecords_compression="GZIP",
    )
    return dfpred.values.tolist()
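
Eventually the idea is to score the files that actually landed in the partition instead of the hardcoded path, i.e. something like the following (a sketch; that files_list accepts a list of paths here is an assumption):

    # Intended call (sketch): pass the partition's own files instead of the hardcoded path
    dfpred, _ = predict_evaluate1(
        files_list=score_files,
        estimator=estimator,
        checkpoint_path=None,
        features_dict=features_dict,
        input_tfrecords_compression="GZIP",
    )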

The scoring itself is done by the function predict_evaluate1, which works perfectly in non-distributed mode. However, when I run the distributed version I get this error:

return parse_example_v2(serialized, features, example_names, name)
File "/opt/conda/miniconda3/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py", line 201, in wrapper
return target(*args, **kwargs)
File "/opt/conda/miniconda3/lib/python3.8/site-packages/tensorflow/python/ops/parsing_ops.py", line 309, in parse_example_v2
params = _ParseOpParams.from_features(features, [
File "/opt/conda/miniconda3/lib/python3.8/site-packages/tensorflow/python/ops/parsing_config.py", line 451, in from_features
raise ValueError("Unsupported %s %s." %
ValueError: Unsupported FixedLenFeature FixedLenFeature(shape=(), dtype=tf.string, default_value=None).

I suspect this is connected to my input_fn, and in particular to the parser defined inside it:

from typing import Dict, List, Optional

import tensorflow as tf


def input_fn1(
    file_names: List[str],
    batch_size: int,
    shuffle_buffer_size: int,
    features_dict: Dict[str, tf.io.FixedLenFeature],
    num_epochs: int = 1,
    n_examples: Optional[int] = None,
    compression: str = "",
    parallel_files_reads: int = 10,
    deterministic_interleave: bool = False,
) -> tf.data.TFRecordDataset:
    """
    Input function to provide data for tf.Estimator
    Args:
        file_names (List[str]): List of files containing features
        batch_size (int): Batch size
        shuffle_buffer_size (int): Size of the buffer to be shuffled
        features_dict (Dict[str, tf.io.FixedLenFeature]): Dictionary with features
        num_epochs (int): Number of epochs to run
        n_examples (int): Number of random examples to use for train or prediction
        compression (str): Compression codec of tfrecords
        parallel_files_reads (int): Number of files to be read in parallel
        deterministic_interleave (bool): Use deterministic=True in dataset.interleave
    """
    
    # Earlier per-record parser, currently commented out; parsing is now done by the
    # lambda passed to dataset.map below.
    #def parser(record, features_dict):
    #    parsed_features = tf.io.parse_single_example(record, features_dict)
    #    return parsed_features, parsed_features['target']
    
    files = tf.data.Dataset.list_files(file_names)
    dataset = files.interleave(
        lambda x: tf.data.TFRecordDataset(
            x, compression_type=compression, num_parallel_reads=parallel_files_reads
        ).prefetch(buffer_size=tf.data.experimental.AUTOTUNE),
        cycle_length=tf.data.experimental.AUTOTUNE,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        deterministic=deterministic_interleave,
    )

    if n_examples is not None:
        dataset = dataset.take(n_examples)

    dataset = dataset.map(lambda x: tf.io.parse_single_example(x, features_dict), 
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if shuffle_buffer_size > 0:
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    if num_epochs > 1:
        dataset = dataset.cache()
    return dataset
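
For completeness, features_dict is built elsewhere in my code as a plain dict of tf.io.FixedLenFeature specs, roughly of this shape (the feature names below are just an illustration, not the real ones):

# Illustration only: the real feature names and dtypes come from my feature config
features_dict = {
    "some_string_feature": tf.io.FixedLenFeature(shape=(), dtype=tf.string),
    "some_numeric_feature": tf.io.FixedLenFeature(shape=(), dtype=tf.float32, default_value=0.0),
    "target": tf.io.FixedLenFeature(shape=(), dtype=tf.float32, default_value=0.0),
}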