The scan method is neat. It lets you iterate over the elements of a dataset and, at the same time, carry a state from one iteration to the next.
Just think about a simple dataset that produces the values 1, 2, 3, 4, 5.
With the map function, you can loop over every single element of this dataset and apply a transformation.
dataset.map(lambda x: x*2) # 2, 4, 6, 8, 10
With scan, instead, you can carry over some information from the previous iterations.
So, for example, if you want to add the previous element to the current one, you can use scan like this:
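import tensorflow as tf

initial_state = tf.constant(0)

def scan_func(old_state, input_element):
    # The state handed to the next iteration is the current element.
    new_state = input_element
    # The output is the current element plus the previous one.
    output_element = input_element + old_state
    return new_state, output_element

dataset.scan(initial_state, scan_func)  # 1 + 0, 2 + 1, 3 + 2, ...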
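The snippet above leaves the dataset itself undefined, so here is a minimal end-to-end sketch. It assumes a TensorFlow 2.x release recent enough to expose scan as a method on tf.data.Dataset (older releases offered it as tf.data.experimental.scan, applied through dataset.apply). Building the dataset with from_tensor_slices and the running_total_func variant are my own additions for illustration, not something from the original solutions.

```python
import tensorflow as tf

# Assumption: the example dataset is built from a plain Python list,
# which yields int32 scalars matching the int32 initial state below.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5])

initial_state = tf.constant(0)


def scan_func(old_state, input_element):
    # Remember the current element for the next iteration.
    new_state = input_element
    # Emit the current element plus the previous one.
    output_element = input_element + old_state
    return new_state, output_element


def running_total_func(old_state, input_element):
    # Illustrative variant: the state accumulates everything seen so far.
    new_state = old_state + input_element
    return new_state, new_state


previous_sums = [
    int(x) for x in dataset.scan(initial_state, scan_func).as_numpy_iterator()
]
running_totals = [
    int(x)
    for x in dataset.scan(initial_state, running_total_func).as_numpy_iterator()
]

print(previous_sums)   # [1, 3, 5, 7, 9]
print(running_totals)  # [1, 3, 6, 10, 15]
```

Note that the state returned by scan_func has to keep the same dtype and shape as initial_state, which is why the initial state and the dataset elements are both int32 scalars here.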
So you carry the old_state over to the next iteration, and every time you iterate over a new input_element you can generate a new_state (which becomes the old_state input for the next iteration) and produce, as output, the output_element.
I used scan in several solutions (which is super helpful); I suggest you read all the articles and search inside the code to see how I used it. I hope it’s helpful for you to understand a little bit more about how to use this great feature.
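If it helps to see the state hand-off without TensorFlow, here is a plain-Python sketch of the same idea. The scan generator below is a hypothetical stand-in I wrote for illustration; it is not the tf.data implementation, it only mimics how old_state, new_state, and output_element flow between iterations.

```python
from typing import Callable, Iterable, Iterator, Tuple, TypeVar

S = TypeVar("S")  # state type
T = TypeVar("T")  # input element type
U = TypeVar("U")  # output element type


def scan(
    elements: Iterable[T],
    initial_state: S,
    scan_func: Callable[[S, T], Tuple[S, U]],
) -> Iterator[U]:
    """Hypothetical pure-Python model of the scan semantics."""
    state = initial_state
    for element in elements:
        # The new_state returned here becomes old_state on the next pass.
        state, output = scan_func(state, element)
        yield output


def scan_func(old_state, input_element):
    new_state = input_element
    output_element = input_element + old_state
    return new_state, output_element


trace = list(scan([1, 2, 3, 4, 5], 0, scan_func))
print(trace)  # [1, 3, 5, 7, 9]
```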