TF-agents mismatched trajectory spec

Below is my simplified implementation of the tf_agents environment and my attempt to tie it to TFUniformReplayBuffer and DynamicEpisodeDriver. It is not going to do anything fancy and will be used only in my tests. The episode is supposed to run for 10 steps with gradually increasing value of observations (which is irrelevant) and generating actions using random policy.

class TestEnv(PyEnvironment):
    def __init__(self):
        n_actions = 2
        self._n_observations = 3
        self._action_spec = BoundedArraySpec(shape=(), dtype=np.int32, minimum=0, maximum=n_actions - 1, name="act")
        self._observation_spec = ArraySpec(shape=(self._n_observations,), dtype=np.float64, name="obs")
        self._idx = 0

    def action_spec(self):
        return self._action_spec

    def observation_spec(self):
        return self._observation_spec

    def _reset(self):
        self._idx = 0
        observation = np.asarray([self._idx] * self._n_observations, dtype=np.float64)
        return time_step.restart(observation)

    def _step(self, action):
        self._idx += 1
        observation = np.asarray([self._idx] * self._n_observations, dtype=np.float64)
        if self._idx >= 10:
            return time_step.termination(observation, reward=0)
        return time_step.transition(observation, reward=1, discount=1)


class TestExperienceReply(object):
    def __init__(self, policy, environment):
        self._policy = policy
        self._environment = environment
        self._replay_buffer = self._make_replay_buffer(self._environment)
        observers = [self._replay_buffer.add_batch]
        self._driver = DynamicEpisodeDriver(self._environment, policy, observers, num_episodes=1)

    def _make_replay_buffer(self, tf_env):
        time_step_spec = tf_env.time_step_spec()
        action_spec = tf_env.action_spec()
        action_step_spec = policy_step.PolicyStep(action_spec, (), tensor_spec.TensorSpec((), tf.int32))
        trajectory_spec = trajectory.from_transition(time_step_spec, action_step_spec, time_step_spec)
        return TFUniformReplayBuffer(data_spec=trajectory_spec, batch_size=tf_env.batch_size)

    def collect(self):
        _, _ = self._driver.run()

env = TestEnv()
tf_env = tf_py_environment.TFPyEnvironment(env)
policy = RandomTFPolicy(time_step_spec=tf_env.time_step_spec(), action_spec=tf_env.action_spec())
experience_replay = TestExperienceReply(policy=policy, environment=tf_env)
experience_replay.collect()

When I try to run it, I get following error triggered within collect method:

E       ValueError: The two structures do not match:
E       (...)
E       Values:
E         Trajectory(
E       {'action': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>,
E        'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
E        'next_step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([1], dtype=int32)>,
E        'observation': <tf.Tensor: shape=(1, 3), dtype=float64, numpy=array([[0., 0., 0.]])>,
E        'policy_info': (),
E        'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
E        'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>})
E       vs.
E         Trajectory(
E       {'action': BoundedTensorSpec(shape=(), dtype=tf.int32, name='act', minimum=array(0, dtype=int32), maximum=array(1, dtype=int32)),
E        'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)),
E        'next_step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
E        'observation': TensorSpec(shape=(3,), dtype=tf.float64, name='obs'),
E        'policy_info': TensorSpec(shape=(), dtype=tf.int32, name=None),
E        'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward'),
E        'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type')}).

Looks like I am somehow missing an extra tensor dimension somewhere but I really can not figure out where and why. Does anyone see any problem with the code above?

1 Like