Contextual bandits - understanding the examples

Hi all,

I’m trying to understand the tutorials on Multi-Armed Bandit:

  1. Tutorial on Multi-Armed Bandits in TF-Agents  |  TensorFlow Agents (Section “A Real Contextual Bandit Example”), and
  2. A Tutorial on Multi-Armed Bandits with Per-Arm Features  |  TensorFlow Agents

If I wanted to simulate a contextual bandit for displaying online ads and use the click-through rate (CTR = clicks/impressions) to separate good ads from bad ones, how would I define the true CTR values in the context of these two tutorials? Below I explain my reasoning, with simple code attached for cases 1) and 2). However, feel free to skip that and go straight to telling me the correct way to model this if you wish :slight_smile: All input is greatly appreciated.

For 1), I wrote the code below, where the true CTR values of the three ads are
arm0_param = [0.01, 0.02, 0.03, 0.07] # CTRs for ad 0
arm1_param = [0.02, 0.06, 0.01, 0.06] # CTRs for ad 1
arm2_param = [0.06, 0.01, 0.02, 0.03] # CTRs for ad 2

with age group as context: age_groups = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
So for a viewer in age group 0, the true CTR of ad 1 is 0.02 (a 2% click rate).
This code seems to more or less work: it outputs estimated CTRs that agree with my true CTRs within the uncertainty, and it samples the way a Thompson sampling agent should. But the concept seems very simple, and the posteriors for the rewards are completely separate from each other (that is, one model per true CTR, or equivalently, one simple multi-armed bandit per context). This works, but it requires a lot of data, and I imagine there are much more efficient ways of modelling this kind of system.
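
To make the "one bandit per context" structure concrete, here is a quick NumPy-only sanity check (no TF-Agents involved) of the true CTR table I use below, and of which ad is optimal for each age group:

import numpy as np

# True CTR table: rows = ads, columns = age groups (same numbers as arm0/1/2_param below)
true_ctrs = np.array([
    [0.01, 0.02, 0.03, 0.07],  # ad 0
    [0.02, 0.06, 0.01, 0.06],  # ad 1
    [0.06, 0.01, 0.02, 0.03],  # ad 2
])

print(true_ctrs.argmax(axis=0))  # optimal ad per age group: [2 1 0 0]
print(true_ctrs.max(axis=0))     # its CTR per age group:    [0.06 0.06 0.03 0.07]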

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
from tf_agents.bandits.agents import linear_thompson_sampling_agent
from tf_agents.bandits.environments import stationary_stochastic_py_environment as sspe
from tf_agents.bandits.metrics import tf_metrics
from tf_agents.drivers import dynamic_step_driver
from tf_agents.replay_buffers import tf_uniform_replay_buffer
import tensorflow as tf
from tf_agents.environments import tf_py_environment
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

import numpy as np
import matplotlib.pyplot as plt

# Define the age groups as one-hot encoded vectors
age_groups = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
# Define the CTRs for the three ads
batch_size = 1 # @param
arm0_param = [0.01, 0.02, 0.03, 0.07] # CTRs for ad 0
arm1_param = [0.02, 0.06, 0.01, 0.06] # CTRs for ad 1
arm2_param = [0.06, 0.01, 0.02, 0.03] # CTRs for ad 2
arms_param = [arm0_param, arm1_param, arm2_param]

# Bookkeeping: clicks (dict_mu) and impressions (actions) per (ad, age group),
# used only to check the measured CTRs against the true ones at the end.
dict_mu = {}
dict_mu[0], dict_mu[1], dict_mu[2] = [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]

actions = {}
actions[0] = [0, 0, 0, 0]
actions[1] = [0, 0, 0, 0]
actions[2] = [0, 0, 0, 0]

def find_position(a, el):
  for i in range(len(a)):
    if np.array_equal(a[i], el):
      return i
  return -1

def context_sampling_fn(batch_size):
  """Contexts representing age groups."""
  def _context_sampling_fn():
    # Randomly select an age group for each instance in the batch to simulate someone viewing the ad
    indices = np.random.randint(0, 4, batch_size)
    tmp = age_groups[indices].astype(np.float32)
    return tmp
  return _context_sampling_fn

class LinearNormalReward(object):
  """Bernoulli (click/no-click) reward whose mean is dot(context, theta); name kept from the tutorial."""
  def __init__(self, theta):
    self.theta = theta
  def __call__(self, x):
    mu = np.dot(x, self.theta)
    reward = np.random.binomial(1, mu) # using binomial as reward is 1 or 0 (click or no click), not normal
    sel_arm = find_position(arms_param, self.theta)
    sel_context = find_position(age_groups, x)
    dict_mu[sel_arm][sel_context] += reward
    actions[sel_arm][sel_context] += 1
    return reward

arm0_reward_fn = LinearNormalReward(arm0_param)
arm1_reward_fn = LinearNormalReward(arm1_param)
arm2_reward_fn = LinearNormalReward(arm2_param)

environment = tf_py_environment.TFPyEnvironment(
    sspe.StationaryStochasticPyEnvironment(
        context_sampling_fn(batch_size),
        [arm0_reward_fn, arm1_reward_fn, arm2_reward_fn],
        batch_size=batch_size))

observation_spec = tensor_spec.TensorSpec([4], tf.float32)
time_step_spec = ts.time_step_spec(observation_spec)
action_spec = tensor_spec.BoundedTensorSpec(
    dtype=tf.int32, shape=(), minimum=0, maximum=2)

agent = linear_thompson_sampling_agent.LinearThompsonSamplingAgent(
    time_step_spec=time_step_spec, action_spec=action_spec)

def compute_optimal_reward(observation):
  print("compute_optimal_reward")
  expected_reward_for_arms = [
      tf.linalg.matvec(observation, tf.cast(arm0_param, dtype=tf.float32)),
      tf.linalg.matvec(observation, tf.cast(arm1_param, dtype=tf.float32)),
      tf.linalg.matvec(observation, tf.cast(arm2_param, dtype=tf.float32))]
  optimal_action_reward = tf.reduce_max(expected_reward_for_arms, axis=0)
  return optimal_action_reward

regret_metric = tf_metrics.RegretMetric(compute_optimal_reward)

num_iterations = 5000 # @param
steps_per_loop = 1 # @param

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.policy.trajectory_spec,
    batch_size=batch_size,
    max_length=steps_per_loop)

observers = [replay_buffer.add_batch, regret_metric]

driver = dynamic_step_driver.DynamicStepDriver(
    env=environment,
    policy=agent.collect_policy,
    num_steps=steps_per_loop * batch_size,
    observers=observers)

policy = agent.collect_policy

for i in range(num_iterations):
  time_step, policy_state = driver.run()
  loss_info = agent.train(replay_buffer.gather_all())
  replay_buffer.clear()

# These will output the measured CTRs of the ads
print(np.asarray(dict_mu[0])/np.asarray(actions[0]))
print(np.asarray(dict_mu[1])/np.asarray(actions[1]))
print(np.asarray(dict_mu[2])/np.asarray(actions[2]))

# This will show how many times each action was selected
print("Actions")
print(actions)

For 2), I wrote the code below, with 4 ads (or ad categories; each category can contain several ads) whose true base CTR values are
true_category = np.array([0.1, 0.15, 0.2, 0.25])
The context is gender (not age group as in 1)), and gender influences the CTR like this:
true_gender_influence = np.array([0.01, -0.02])

Thus, if the global context (gender) is e.g. [1, 0] and the per-arm context (category) is [0, 1, 0, 0], the true CTR will be 0.01 + 0.15 = 0.16 (from the dot product of [1, 0, 0, 1, 0, 0] and [0.01, -0.02, 0.1, 0.15, 0.2, 0.25] in linear_normal_reward_fn).

I have a feeling I’ve misunderstood something, as this is a very simple and not very useful model: gender [1, 0] will always have a slightly higher probability of clicking an ad, and gender [0, 1] will always have a slightly lower probability of clicking an ad.
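
To show what I mean, here is the expected-reward computation that linear_normal_reward_fn (below) performs, worked out by hand in plain NumPy, plus the full table of category × gender combinations the agent should recover:

import numpy as np

true_category = np.array([0.1, 0.15, 0.2, 0.25])
true_gender_influence = np.array([0.01, -0.02])
reward_param = np.concatenate([true_gender_influence, true_category])

gender = np.array([1, 0])          # global context
category = np.array([0, 1, 0, 0])  # per-arm context (dressers)
x = np.concatenate([gender, category])
print(np.dot(x, reward_param))  # 0.01 + 0.15 = 0.16

# All category x gender combinations (rows = categories, columns = genders)
print(true_category[:, np.newaxis] + true_gender_influence[np.newaxis, :])

As the table shows, gender only ever shifts each category’s CTR by the same constant (+0.01 or -0.02), which is exactly why the model feels trivial to me.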

import os
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
os.environ['TF_USE_LEGACY_KERAS'] = '1'
import functools
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tf_agents.bandits.agents import lin_ucb_agent, linear_thompson_sampling_agent

from tf_agents.bandits.environments import stationary_stochastic_per_arm_py_environment as p_a_env
from tf_agents.bandits.metrics import tf_metrics as tf_bandit_metrics
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import tf_py_environment
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

nest = tf.nest

# The standard deviation of the Gaussian noise added to the rewards
# (np.random.normal takes a standard deviation, despite the name VARIANCE).
VARIANCE = 0.1

n_ads = 4
n_genders = 2
BATCH_SIZE = 1

true_category = np.array([0.1, 0.15, 0.2, 0.25]) 
true_gender_influence = np.array([0.01, -0.02]) 
# The categories for the ads, represented as one-hot vectors
ad_categories = [
  [1, 0, 0, 0],  # beds
  [0, 1, 0, 0],  # dressers
  [0, 0, 1, 0], # mirrors
  [0, 0, 0, 1], # hot dogs
]

reward_param = np.concatenate([true_gender_influence,true_category])

# Reshape true_category to a column vector and true_gender_influence to a row vector
true_category2 = true_category[:, np.newaxis]
true_gender_influence2 = true_gender_influence[np.newaxis, :]
# Use broadcasting to create a 2D array for all combinations of true_category and true_gender_influence
# This is what can be measured in real life
combinations = true_category2 + true_gender_influence2

# Bookkeeping: summed rewards (a_mu) and impression counts (a_actions) per (category, gender)
a_mu = np.zeros((n_ads, n_genders))
a_actions = np.zeros((n_ads, n_genders))

def find_position(a, el):
  for i in range(len(a)):
    if np.array_equal(a[i], el):
      return i
  return -1

def global_context_sampling_fn():
  """This function generates a single global observation vector (one-hot gender)."""
  gender = np.random.randint(0, n_genders)
  gender_one_hot = np.zeros(n_genders, dtype=np.float32)
  gender_one_hot[gender] = 1.0
  return gender_one_hot

# This function just cycles through the 4 per-arm contexts (ad categories), so that the
# Thompson sampling agent gets to sample the posterior for each of them.
current_ad_index = 0  # index of the category returned on the next call
def per_arm_context_sampling_fn():
  """"This function generates a single per-arm observation vector."""
  global current_ad_index
  # Get the category for the current ad
  category = ad_categories[current_ad_index]
  # Update the current ad index for the next call
  current_ad_index = (current_ad_index + 1) % len(ad_categories)
  return np.array(category, dtype=np.float32)

def linear_normal_reward_fn(x):
  """This function generates a reward from the concatenated global and per-arm observations."""
  pos_gender = find_position(x[0:n_genders],1)
  pos_ctr = find_position(x[n_genders:],1)
  mu = np.dot(x, reward_param)
  tmp =  np.random.normal(mu, VARIANCE)
  a_actions[pos_ctr][pos_gender] += 1
  # A binomial (click/no-click) reward could be used instead of the normal:
  # mu = np.dot(x, np.concatenate([true_gender_influence, true_category]))
  # mu = 1 / (1 + np.exp(-mu))  # possibly apply a sigmoid to keep mu between 0 and 1
  # tmp = np.random.binomial(1, mu)
  a_mu[pos_ctr][pos_gender] += tmp
  return tmp

per_arm_py_env = p_a_env.StationaryStochasticPerArmPyEnvironment(
    global_context_sampling_fn,
    per_arm_context_sampling_fn,
    n_ads,
    linear_normal_reward_fn,
    batch_size=BATCH_SIZE
)
current_ad_index = 0  # Reset the cycling index (the environment may sample contexts while building its specs)

per_arm_tf_env = tf_py_environment.TFPyEnvironment(per_arm_py_env)
observation_spec = per_arm_tf_env.observation_spec()
time_step_spec = ts.time_step_spec(observation_spec)
action_spec = tensor_spec.BoundedTensorSpec(dtype=tf.int32, shape=(), minimum=0, maximum=n_ads - 1)
agent = linear_thompson_sampling_agent.LinearThompsonSamplingAgent(
    time_step_spec=time_step_spec,
    action_spec=action_spec,
    accepts_per_arm_features=True)

def _all_rewards(observation, hidden_param):
  """Outputs rewards for all actions, given an observation."""
  hidden_param = tf.cast(hidden_param, dtype=tf.float32)
  global_obs = observation['global']
  per_arm_obs = observation['per_arm']
  num_actions = tf.shape(per_arm_obs)[1]
  tiled_global = tf.tile(tf.expand_dims(global_obs, axis=1), [1, num_actions, 1])
  concatenated = tf.concat([tiled_global, per_arm_obs], axis=-1)
  rewards = tf.linalg.matvec(concatenated, hidden_param)
  return rewards

def optimal_reward(observation):
  """Outputs the maximum expected reward for every element in the batch."""
  return tf.reduce_max(_all_rewards(observation, reward_param), axis=1)

regret_metric = tf_bandit_metrics.RegretMetric(optimal_reward)

num_iterations = 200 # @param
steps_per_loop = 1 # @param

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.policy.trajectory_spec,
    batch_size=BATCH_SIZE,
    max_length=steps_per_loop)

observers = [replay_buffer.add_batch, regret_metric]

driver = dynamic_step_driver.DynamicStepDriver(
    env=per_arm_tf_env,
    policy=agent.collect_policy,
    num_steps=steps_per_loop * BATCH_SIZE,
    observers=observers)
regret_values = []
policy = agent.collect_policy

for i in range(num_iterations):
  time_step, policy_state = driver.run()
  loss_info = agent.train(replay_buffer.gather_all())
  replay_buffer.clear()
  regret_values.append(regret_metric.result())
  
print("True values")
print(combinations)
print("Results from simulated measurement")
print("Mean rewards (should be similar to true values)")
print(a_mu/a_actions)
print("Actions")
print(a_actions)