Help distributing TensorFlow predictions with Spark


I would like to distribute my TensorFlow predictions using a Spark RDD that contains one TFRecord file per partition. Here is a snippet of my code:

# Driver side: tag every score-file row with a partition id, map that id to a
# partition index, and run `distribute_score` once per partition.
df_score_files = df_score_files.withColumn(
    "part_idx", f.monotonically_increasing_id()
).repartition("part_idx")
# NOTE(review): monotonically_increasing_id() is not contiguous; presumably
# create_id_map returns (count, {raw_id: dense_index in 0..count-1}) — confirm.
n_part_idx, part_idx_map = create_id_map(df_score_files, id_name="part_idx")
# Broadcast the lookup table once instead of capturing it in every task closure.
part_idx_map = spark.sparkContext.broadcast(part_idx_map)
predictions_rdd = (
    df_score_files.rdd
    # Key each Row by its dense partition index.
    .map(lambda row: (part_idx_map.value[row.part_idx], row))
    # Identity partitioner: the key itself selects the target partition.
    .partitionBy(n_part_idx, lambda x: x)
    # Each partition's (index, Row) pairs are scored by distribute_score.
    .mapPartitions(distribute_score)
)

The function `distribute_score` is applied to each partition:

def distribute_score(iterator):
    """Score the tfrecord files assigned to a single Spark partition.

    Args:
        iterator: Iterator over the ``(partition_index, row)`` pairs produced by
            ``partitionBy``. Each ``row`` must provide ``asDict()`` with a
            ``"score_files"`` entry.

    Returns:
        list: ``dfpred.values.tolist()`` for this partition's files, or ``[]``
        when the partition received no rows.
    """
    score_files = [pair[1].asDict()["score_files"] for pair in iterator]
    if not score_files:
        # Empty partition: there is nothing to score, so skip the model call.
        return []
    # NOTE(review): if predict_evaluate1 relies on a tf.io feature dict (e.g.
    # FixedLenFeature) created on the driver, that dict is pickled to this
    # executor. "ValueError: Unsupported FixedLenFeature FixedLenFeature(...)"
    # means the unpickled objects no longer pass TF's isinstance check.
    # Building the feature dict inside this function avoids that — confirm.
    # TODO(review): the pasted call was truncated after files_list=...; re-add
    # any remaining predict_evaluate1 arguments here.
    # Score this partition's own files, not a fixed path, so partitions differ.
    dfpred, _ = predict_evaluate1(files_list=score_files)
    return dfpred.values.tolist()

The scoring itself is done by the function `predict_evaluate1`, which works perfectly in non-distributed mode. However, when I run the distributed version, I get this error:

    return parse_example_v2(serialized, features, example_names, name)
  File "/opt/conda/miniconda3/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py", line 201, in wrapper
    return target(*args, **kwargs)
  File "/opt/conda/miniconda3/lib/python3.8/site-packages/tensorflow/python/ops/parsing_ops.py", line 309, in parse_example_v2
    params = _ParseOpParams.from_features(features, [
  File "/opt/conda/miniconda3/lib/python3.8/site-packages/tensorflow/python/ops/parsing_config.py", line 451, in from_features
    raise ValueError("Unsupported %s %s." %
ValueError: Unsupported FixedLenFeature FixedLenFeature(shape=(), dtype=tf.string, default_value=None).

I suspect this is related to my `input_fn`, and in particular to the parser defined in it:

def _restore_feature_spec(features_dict):
    """Return a copy of ``features_dict`` whose values are genuine ``tf.io`` feature configs.

    TF's example parser accepts a feature only if it passes an ``isinstance``
    check against the ``tf.io`` config classes. The error
    ``ValueError: Unsupported FixedLenFeature FixedLenFeature(...)`` means a value's
    class is *named* ``FixedLenFeature`` but is not ``tf.io.FixedLenFeature``.
    NOTE(review): this presumably happens when the dict is pickled on the Spark
    driver and unpickled on an executor as a plain namedtuple — confirm.
    Values that already pass the check are kept unchanged.
    """
    restored = {}
    for key, feature in features_dict.items():
        tf_cls = getattr(tf.io, type(feature).__name__, None)
        if (
            isinstance(tf_cls, type)
            and isinstance(feature, tuple)
            and not isinstance(feature, tf_cls)
        ):
            # tf.io feature configs are namedtuples, so their fields map
            # positionally onto the real class's constructor.
            feature = tf_cls(*feature)
        restored[key] = feature
    return restored


def input_fn1(
    file_names: List[str],
    batch_size: int,
    shuffle_buffer_size: int,
    features_dict: Dict[str, object],
    num_epochs: int = 1,
    n_examples: int = None,
    compression: str = "",
    parallel_files_reads: int = 10,
    deterministic_interleave: bool = False,
) -> "tf.data.Dataset":
    """Input function to provide data for tf.Estimator.

    Args:
        file_names (List[str]): List of files containing features.
        batch_size (int): Batch size.
        shuffle_buffer_size (int): Size of the shuffle buffer; 0 disables shuffling.
        features_dict (Dict[str, object]): Mapping of feature name to ``tf.io``
            feature config, passed to ``tf.io.parse_single_example``.
        num_epochs (int): Number of epochs to run.
        n_examples (int): If given, only the first ``n_examples`` records are used.
        compression (str): Compression codec of the tfrecords ("" for none).
        parallel_files_reads (int): Number of files to be read in parallel.
        deterministic_interleave (bool): Passed as ``deterministic`` to
            ``dataset.interleave``.

    Returns:
        tf.data.Dataset: Batches of parsed feature dicts, prefetched.
    """
    features_dict = _restore_feature_spec(features_dict)
    autotune = tf.data.experimental.AUTOTUNE

    files = tf.data.Dataset.from_tensor_slices(file_names)
    # Read `parallel_files_reads` files concurrently. Each interleaved dataset
    # holds a single path, so the parallelism belongs on interleave itself.
    dataset = files.interleave(
        lambda path: tf.data.TFRecordDataset(path, compression_type=compression),
        cycle_length=parallel_files_reads,
        num_parallel_calls=autotune,
        deterministic=deterministic_interleave,
    )

    if n_examples is not None:
        dataset = dataset.take(n_examples)

    dataset = dataset.map(
        lambda record: tf.io.parse_single_example(record, features_dict),
        num_parallel_calls=autotune,
    )
    # Cache parsed records before shuffle/repeat so epochs after the first
    # reuse them instead of re-reading and re-parsing the files.
    if num_epochs > 1:
        dataset = dataset.cache()
    if shuffle_buffer_size > 0:
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(autotune)