How to use tf.data.Dataset.scan

Can anyone explain in more detail how to use the scan() method of a TensorFlow dataset effectively? The explanation and usage example in the documentation are not very clear to me.

Hi @Atia, tf.data.Dataset.scan lets you apply a stateful transformation to a dataset: a state value is carried from one element to the next while the elements are processed in order. Let's take a small dataset:

import tensorflow as tf

dataset = tf.data.Dataset.range(1, 10)  # elements 1 through 9

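The scan method takes two arguments, initial_state and scan_func. scan_func receives the current state and the next element, and must return a (new_state, output_element) pair. Let's define both: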

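# int64 matches the dtype of the elements produced by Dataset.range
initial_state = tf.constant(0, dtype=tf.int64)

def scan_func(state, element):
  # Value emitted for this element: the element plus everything seen so far
  transformed_element = element + state
  # State carried into the next call
  new_state = state + element
  return new_state, transformed_element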

Here scan_func adds the current state to each element and also stores that sum as the new state. In the first iteration the element is 1 and the state is 0, so the emitted value is element + state = 1 + 0 = 1 and the new state is state + element = 0 + 1 = 1. In the second iteration the state is therefore 1, the element is 2, and the emitted value is 3.
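If it helps to see the mechanics without TensorFlow, the loop below is a rough plain-Python sketch of what scan does with this scan_func. The helper name emulate_scan is made up for illustration and is not part of any library; the real method builds a dataset rather than a list.

```python
def emulate_scan(elements, initial_state, scan_func):
    """Hypothetical stand-in for Dataset.scan, for illustration only."""
    state = initial_state
    outputs = []
    for element in elements:
        # scan_func returns the state for the next step and the value to emit
        state, output = scan_func(state, element)
        outputs.append(output)
    return outputs


def scan_func(state, element):
    transformed_element = element + state
    new_state = state + element
    return new_state, transformed_element


trace = emulate_scan(range(1, 4), 0, scan_func)
print(trace)  # [1, 3, 6]
```

The state is threaded through the calls one element at a time, which is why the order of the elements matters.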

dataset = dataset.scan(initial_state=initial_state, scan_func=scan_func)

The dataset now yields the running sum of the original elements: [1, 3, 6, 10, 15, 21, 28, 36, 45]. Thank you.
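To check the result yourself, here is the whole example in one piece. This is a minimal sketch that assumes a TensorFlow 2.x release recent enough to provide Dataset.scan as a method, running in eager mode; the names summed and result are only illustrative.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(1, 10)  # int64 elements 1 through 9

# The initial state must have the same dtype as the state scan_func returns
initial_state = tf.constant(0, dtype=tf.int64)


def scan_func(state, element):
    transformed_element = element + state  # value emitted for this element
    new_state = state + element            # state passed to the next call
    return new_state, transformed_element


summed = dataset.scan(initial_state=initial_state, scan_func=scan_func)

# Pull the elements out as plain Python ints so they are easy to inspect
result = [int(x) for x in summed.as_numpy_iterator()]
print(result)
```

Note that scan returns a new dataset and leaves the original one unchanged, so you need to keep the returned object.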