TextVectorization for Text Nested List

Hello everyone,

I want to reproduce this example, but with input_data as a nested list of strings instead. In other words:

import tensorflow as tf

text_dataset = tf.data.Dataset.from_tensor_slices(["foo", "bar", "baz"])
max_features = 5000
max_len = 4

# -------- EXPERIMENTAL ADDENDUM -----------------
def split_on_comma(input_data):
    return tf.strings.split(input_data, sep=", ")
# ------------------------------------------------

vectorize_layer = tf.keras.layers.TextVectorization(
    max_tokens=max_features,
    output_mode='int',
    output_sequence_length=max_len,
    split=split_on_comma
)

with tf.device("/CPU:0"):
    vectorize_layer.adapt(text_dataset.batch(64))

inputs = tf.keras.Input(shape=(1,), dtype=tf.string)
x = vectorize_layer(inputs)
model = tf.keras.Model(inputs=inputs, outputs=x)

# input_data = [["foo qux bar"], ["qux baz"]] # Original input, I DO NOT want this...
input_data = [["foo", "qux", "bar"], ["qux baz"]] # ...I WANT THIS AS INPUT INSTEAD
input_data = [[', '.join(x)] for x in input_data] # MY EXPERIMENTAL WORKAROUND 
model.predict(input_data)

The thing is that I have a dataset where one of the features is a list of strings; each list has a different length, and the strings do not share any context with each other. I want to use this feature as training input for a Model built with the Functional API. Splitting it into separate one-hot encoded (OHE) columns is out of the question, as the feature is highly sparse (~7K distinct classes).

I am a bit concerned that my implementation is neither the best approach nor correct, because I am getting different results. For instance, this is the example's original output:

array([[2,1,4,0],
       [1,3,0,0]])

And what I am getting is:

array([[1,0,0,0],
       [1,0,0,0]])

Finally, I am not quite sure how (if?) Ragged Tensors would be of use with TextVectorization layer (again, the goal is building a Model using the Functional API). I have found some examples, but none focused on TextVectorization. Also, I understand that TextVectorization processes "one string per sample"; if so, what alternatives would you suggest for what I am aiming for here?

Has anyone had this problem before? How did you tackle it?

Thank you very much.

array([[1,0,0,0],
       [1,0,0,0]])

Somehow your split is failing, so each sample comes through as a single token. The zeros are padding, and the ones are the out-of-vocabulary token ([UNK]). I looked into this a bit: the default standardize, "lower_and_strip_punctuation", runs before the split and removes the commas, so the ", " separator is never found.
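A minimal sketch of that diagnosis, reusing split_on_comma from the question and assuming TF 2.x with tf.keras, where standardize may be a callable. Passing tf.strings.lower as the standardize callable is our choice, not something from the thread: it lowercases without stripping punctuation, so the commas survive until the split step. The names broken and fixed are illustrative.

```python
import tensorflow as tf

def split_on_comma(input_data):
    # Split each string on the literal separator ", " (returns a RaggedTensor)
    return tf.strings.split(input_data, sep=", ")

vocab_source = tf.data.Dataset.from_tensor_slices(["foo", "bar", "baz"])

# Default standardize ("lower_and_strip_punctuation") removes the commas
# before split_on_comma runs, so "foo, qux, bar" becomes "foo qux bar".
broken = tf.keras.layers.TextVectorization(
    output_mode="int", output_sequence_length=4, split=split_on_comma)
broken.adapt(vocab_source.batch(64))

# Lowercasing only: punctuation (including the commas) is kept for the split.
fixed = tf.keras.layers.TextVectorization(
    output_mode="int", output_sequence_length=4,
    standardize=tf.strings.lower, split=split_on_comma)
fixed.adapt(vocab_source.batch(64))

# Shape (batch, 1), like the Input(shape=(1,)) in the question
samples = tf.constant([["foo, qux, bar"], ["qux, baz"]])
broken_ids = broken(samples).numpy()  # every sample is one unknown token
fixed_ids = fixed(samples).numpy()    # one id per comma-separated item
print(broken_ids)
print(fixed_ids)
```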

I am not quite sure how (if?) Ragged Tensors would be of use with TextVectorization layer

Ragged Tensors are exactly what you're looking for here. Your split function already returns ragged tensors.
If you split before passing the data to the TextVectorization layer (the result will be ragged), just tell the layer not to split again by passing split=None.
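The pre-splitting approach can be sketched as follows, assuming TF 2.x (tf.keras). tf.strings.split returns a tf.RaggedTensor, which the layer accepts directly once split=None; ragged=True keeps the output variable-length instead of padding it. The sample strings are hypothetical.

```python
import tensorflow as tf

# One comma-joined string per sample (hypothetical data)
joined = tf.constant(["foo, qux, bar", "qux, baz"])
# Split up front: the result is a tf.RaggedTensor of shape (2, None)
pre_split = tf.strings.split(joined, sep=", ")

vectorize_layer = tf.keras.layers.TextVectorization(
    split=None,   # tokens are already separated, do not split again
    ragged=True,  # return a RaggedTensor instead of a padded dense tensor
)
vectorize_layer.adapt(
    tf.data.Dataset.from_tensor_slices(["foo", "bar", "baz"]).batch(64))

ids = vectorize_layer(pre_split)
print(ids)
```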

the goal is building a Model using the Functional API

If you have a ragged input to a functional model, use keras.Input(..., ragged=True).

If you're not using any other features of TextVectorization, then StringLookup may be a good alternative, since it just maps strings to integer indices without the standardization and splitting steps.
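A minimal StringLookup sketch, assuming TF 2.x and tokens that are already lowercased. Two differences from TextVectorization are worth noting: StringLookup applies no standardization (so "Python" and "python" are distinct tokens), and with its default mask_token=None the out-of-vocabulary index is 0 rather than 1. The vocabulary and inputs are hypothetical.

```python
import tensorflow as tf

train_lists = tf.ragged.constant([["python", "java", "c++"], ["netbeans", "django"]])

# Defaults: no mask token, one OOV index ("[UNK]" at index 0)
lookup = tf.keras.layers.StringLookup()
lookup.adapt(train_lists.flat_values)  # adapt on the flat list of tokens

# No lowercasing here: "Python" is not the same token as "python"
ids = lookup(tf.ragged.constant([["python", "cobol"], ["django", "Python"]]))
print(ids)
```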

Hello @markdaoust ,

Thanks for your time and comments.

Sorry for not posting my feedback earlier; I solved this issue by 2022-01-26. Here is my implementation:

import tensorflow as tf

# This data will be adapted on the 'TextVectorization' layer later on
list_of_lists = [
    ["Python", "Java", "C++"],
    ["Netbeans", "Django"],
    ["Many", "Other", "List", "Of", "Lists", "Here"]
]

# The 'TextVectorization' layer is initialized and adapted to the items in 'list_of_lists'
text_dataset = list(set([item.lower() for sublist in list_of_lists for item in sublist]))
text_dataset = tf.data.Dataset.from_tensor_slices(text_dataset) # 'text_dataset' is re-used here (!)
vectorize_layer = tf.keras.layers.TextVectorization(
    split=None,  # TO AVOID LIST SPLITTING
    ragged=True, # TO WORK WITH RAGGED TENSORS ('tf.ragged') LATER ON
    name="skill_vectorizer_layer"
)
with tf.device("/CPU:0"):
    vectorize_layer.adapt(text_dataset.batch(64))
    
# Preprocessing layers, including 'vectorize_layer'
inputs = tf.keras.layers.Input(
    shape=(None,),    # We do NOT know the length of the input lists
    dtype=tf.string,  # We do know that the list elements are strings
    ragged=True,      # We do know the inputs will be ragged tensors
    name="skill_input"
)
x = vectorize_layer(inputs)
model = tf.keras.Model(inputs=inputs, outputs=x)

# NEW samples to which preprocessing will be applied
new_list_of_lists = [
    ["Pyhton", "Django", "Computer Science"],
    ["Wait", "Some", "Of", "These", "Items", "Have", "Not", "Been", "Seen", "Before!"]
]

# Results
new_ragged_input = tf.ragged.constant(new_list_of_lists) # Input MUST be a ragged tensor ('tf.ragged.constant' builds one)
model.predict(new_ragged_input)

Which gets me the following output (dimensions match, no zero padding this time):

<tf.RaggedTensor [[1, 11, 1], [1, 1, 4, 1, 1, 1, 1, 1, 1, 1]]>

Since the output makes sense, the preprocessing layers can now be connected to other layers (Embedding and so on).
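One possible way to connect the ragged output to an Embedding layer, sketched under assumptions that are not from the thread: the ragged ids are padded to a dense tensor with to_tensor(), Embedding(mask_zero=True) treats index 0 (TextVectorization's padding index) as masked, and GlobalAveragePooling1D turns each variable-length list into one fixed-size vector. The embedding size of 8 and the sample inputs are arbitrary.

```python
import tensorflow as tf

list_of_lists = [
    ["Python", "Java", "C++"],
    ["Netbeans", "Django"],
    ["Many", "Other", "List", "Of", "Lists", "Here"],
]
tokens = sorted({item.lower() for sublist in list_of_lists for item in sublist})

vectorize_layer = tf.keras.layers.TextVectorization(split=None, ragged=True)
vectorize_layer.adapt(tf.data.Dataset.from_tensor_slices(tokens).batch(64))

ragged_ids = vectorize_layer(
    tf.ragged.constant([["Python", "Django", "Unknown Thing"], ["Java"]]))

# Pad to a dense batch; 0 is TextVectorization's padding index
dense_ids = ragged_ids.to_tensor(default_value=0)

embedding = tf.keras.layers.Embedding(
    input_dim=vectorize_layer.vocabulary_size(),
    output_dim=8,     # arbitrary embedding size for this sketch
    mask_zero=True,   # padding positions are masked for downstream layers
)
pooling = tf.keras.layers.GlobalAveragePooling1D()

# One fixed-length vector per variable-length list of items
sample_vectors = pooling(embedding(dense_ids))
print(sample_vectors.shape)  # (2, 8)
```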

ADDENDUM: For future readers, a related aspect that might be challenging is building a tf.data.Dataset from ragged tensors (as you will see, it is not that difficult, but information on how to do it is rather scarce). One way to do it is as follows:

  • If your model has a SINGLE INPUT / SINGLE OUTPUT:
dataset = tf.data.Dataset.from_tensor_slices(
    (
        tf.ragged.constant(list_of_lists), # This is my 'list of lists' input, used in this example
        # Other inputs go here; they must be compatible with 'tf.data.Dataset.from_tensor_slices'
        # Your output goes here; it must be compatible with 'tf.data.Dataset.from_tensor_slices'.
    )
)
  • If your model has MULTIPLE INPUTS / SINGLE OR MULTIPLE OUTPUTS:
dataset = tf.data.Dataset.from_tensor_slices(
    (
        (
            tf.ragged.constant(list_of_lists), # This is my 'list of lists' input, used in this example
            # Other inputs go here; they must be compatible with 'tf.data.Dataset.from_tensor_slices'
        ),
        # Your output goes here; it must be compatible with 'tf.data.Dataset.from_tensor_slices'.
        # NOTE: If you have multiple outputs, remember to group them inside a tuple, as with the inputs in this snippet
    )
)
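A small end-to-end check of the multiple-input pattern, assuming TF 2.x. The second input (years_experience) and the labels are hypothetical. Batching the dataset re-assembles the sliced variable-length lists into a tf.RaggedTensor, which is what a ragged keras.Input expects.

```python
import tensorflow as tf

list_of_lists = [
    ["python", "java", "c++"],
    ["netbeans", "django"],
    ["many", "other", "list", "of", "lists", "here"],
]
# Hypothetical extra numeric input and binary labels, one per sample
years_experience = tf.constant([3.0, 1.5, 7.0])
labels = tf.constant([1, 0, 1])

dataset = tf.data.Dataset.from_tensor_slices(
    (
        (tf.ragged.constant(list_of_lists), years_experience),  # inputs
        labels,                                                 # output
    )
)

# Batching turns the variable-length slices back into a RaggedTensor
batched = dataset.batch(2)
(skills, years), y = next(iter(batched))
print(skills)
print(years.numpy(), y.numpy())
```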

And that's it. Of course, checking the official documentation is highly recommended.

I consider this case solved.

Thank you.
