CategoricalDqnAgent not working with an action mask?

Hello, I'm currently training an RL agent with tf_agents to play a card game. Since it worked well with a DQN agent, I tried to improve my results with a categorical DQN agent.
This card game has valid and invalid actions, so I provide an action mask to the agent.

My observation spec looks like this (and I use the same spec as the input_tensor_spec of my categorical_q_network):
{'observation': BoundedTensorSpec(shape=(85,), dtype=tf.int32, name='observation', minimum=array(-100), maximum=array(100)), 'valid_actions': TensorSpec(shape=(91,), dtype=tf.bool, name='valid_actions')}

However, when training starts and the agent receives its first time step, the distribution method in categorical_q_policy.py is called, in particular this block of code:

```python
network_observation = time_step.observation
observation_and_action_constraint_splitter = (
    self.observation_and_action_constraint_splitter)

if observation_and_action_constraint_splitter is not None:
  network_observation, mask = (
      observation_and_action_constraint_splitter(network_observation))

q_logits, policy_state = self._q_network(
    network_observation, step_type=time_step.step_type,
    network_state=policy_state)
```
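For reference, a splitter for the observation dict above could look like the minimal sketch below. The function name is hypothetical, and it assumes the keys 'observation' and 'valid_actions' from my spec. It only indexes into the mapping, so it behaves the same on a dict of tensors as on the plain-Python stand-ins used here.

```python
def observation_and_action_constraint_splitter(observation):
    """Split the nested observation into (network input, action mask).

    Hypothetical helper: assumes the dict keys used in the observation spec
    ('observation' and 'valid_actions').
    """
    network_input = observation["observation"]  # part intended for the Q network
    mask = observation["valid_actions"]          # boolean mask over the actions
    return network_input, mask


# Usage with plain Python stand-ins for the tensors:
obs = {"observation": [1, 1, 1, 30], "valid_actions": [True, False, True]}
net_in, mask = observation_and_action_constraint_splitter(obs)
print(net_in)  # [1, 1, 1, 30]
print(mask)    # [True, False, True]
```

The function is passed to the agent through the observation_and_action_constraint_splitter argument of CategoricalDqnAgent, which forwards it to the policy code quoted above.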

Since I have a mask, I provide an observation_and_action_constraint_splitter, so after the if block network_observation is no longer a dict but a single tensor:

```
tf.Tensor(
[[  1   1   1  30  30   5   4   0   0   0   1   0   0   0   0   0   0   1
    0   0   7   9   5   7   7   7   2   3   2   7   7   7   2   3   2 -99
  -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99
  -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99
  -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99]], shape=(1, 85), dtype=int32)
```

The problem is that when the Q logits are computed, a structure error is raised: the inputs (network_observation) are now a single tensor, while the network's input_tensor_spec still expects a dict {'observation': …, 'valid_actions': …}. This is the error:

```
ValueError: <tf_agents.networks.categorical_q_network.CategoricalQNetwork object at 0x0000021D34EE2D90>: `inputs` and `input_tensor_spec` do not have matching structures:
  .
vs.
  {'observation': ., 'valid_actions': .}
Values:
  [[  1   1   1  30  30   5   4   0   0   0   1   0   0   0   0   0   0   1
    0   0   7   9   5   7   7   7   2   3   2   7   7   7   2   3   2 -99
  -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99
  -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99
  -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99 -99]]
vs.
  {'observation': BoundedTensorSpec(shape=(85,), dtype=tf.int32, name='observation', minimum=array(-100), maximum=array(100)), 'valid_actions': TensorSpec(shape=(91,), dtype=tf.bool, name='valid_actions')}.
```
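The mismatch reported here is the one `tf.nest` gives when a single tensor is compared against a dict of specs. The minimal sketch below reproduces it outside tf_agents, assuming only TensorFlow is installed and using plain `tf.TensorSpec` stand-ins instead of BoundedTensorSpec (the nesting is the same). It also shows that the split-off tensor does match the 'observation' entry on its own, which suggests the network's input_tensor_spec would need to be only that entry rather than the full dict; the helper name `structures_match` is hypothetical.

```python
import tensorflow as tf

# Simplified stand-ins for the specs above (plain tf.TensorSpec instead of
# BoundedTensorSpec; this does not change the nesting structure).
full_spec = {
    "observation": tf.TensorSpec(shape=(85,), dtype=tf.int32, name="observation"),
    "valid_actions": tf.TensorSpec(shape=(91,), dtype=tf.bool, name="valid_actions"),
}

# What the policy hands to the network after the splitter has run:
# a single batched tensor, not a dict.
network_observation = tf.zeros((1, 85), dtype=tf.int32)


def structures_match(inputs, spec):
    """Return True if `inputs` and `spec` have the same nested structure."""
    try:
        tf.nest.assert_same_structure(inputs, spec)
        return True
    except (ValueError, TypeError):
        return False


print(structures_match(network_observation, full_spec))                 # False
print(structures_match(network_observation, full_spec["observation"]))  # True
```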

I'm not very experienced with RL, so I don't want to imply there is a bug in the source code, but I don't see how to get rid of this error, since categorical_q_policy.py deliberately strips the mask from the observation before calling the network.

Thank you in advance for your help!