Lattice model monotonicity

Hi All,

I am trying to use TensorFlow Lattice in a model and enforce monotonicity for a set of features. I am able to do this for a binary classification model using the code structure below:

import numpy as np
import tensorflow as tf
import tensorflow_lattice as tfl
from keras.layers import Input
from keras.models import Model
from keras.utils import plot_model

def create_model(layer1_nodes, layer2_nodes, dense1_nodes, dense2_nodes,
                 inputShape1, inputShape2, num_classes=3,
                 monotonicity="increasing", num_keypoints=20):
    """Build and compile an LSTM classifier whose first feature is calibrated
    by a monotonic piecewise-linear function.

    Args:
        layer1_nodes: Units in the first LSTM layer.
        layer2_nodes: Units in the second LSTM layer.
        dense1_nodes: Units in the first Dense layer.
        dense2_nodes: Units in the second Dense layer.
        inputShape1: Per-example shape of the input whose feature 0 is
            calibrated. It must be rank 2, ``(timesteps, features)``, because
            feature 0 is sliced out of the last axis.
        inputShape2: Per-example shape of the remaining input. It must have the
            same ``timesteps`` as ``inputShape1``, because the two are
            concatenated along the feature axis.
        num_classes: Width of the softmax output. Defaults to 3.
        monotonicity: Constraint passed to ``tfl.layers.PWLCalibration``
            ("increasing", "decreasing" or "none"). Defaults to "increasing".
        num_keypoints: Number of calibration keypoints, spaced evenly on [0, 1].

    Returns:
        A compiled model with inputs ``[input_lattice, input_other]``. It is
        trained with sparse categorical cross-entropy, so labels are integer
        class ids.

    NOTE: Only the calibrator is constrained. The LSTM and Dense layers after
    it are unconstrained, so the class probabilities are NOT guaranteed to be
    monotonic in feature 0. End-to-end monotonicity requires every layer
    between the calibrator and the output to preserve it. Also, softmax
    outputs sum to 1, so raising one class's probability lowers the others.
    In a multi-class model, monotonicity can therefore only be stated per
    class score, or on a single ordered score (e.g. an ordinal setup).
    """
    input_lattice = Input(shape=inputShape1)
    input_other = Input(shape=inputShape2)

    # The keypoints span [0, 1]. This presumably assumes feature 0 is
    # pre-scaled to that range; TODO confirm against the data pipeline.
    calibrator = tfl.layers.PWLCalibration(
        input_keypoints=np.linspace(0.0, 1.0, num=num_keypoints),
        monotonicity=monotonicity,
    )
    # PWLCalibration takes rank-2 input of shape (batch, 1). TimeDistributed
    # applies one shared calibrator at every timestep and returns a
    # (batch, timesteps, 1) tensor, which can be concatenated with
    # input_other along the feature axis.
    calibrated = tf.keras.layers.TimeDistributed(calibrator)(
        input_lattice[:, :, 0:1]
    )

    input_all = tf.keras.layers.Concatenate(axis=-1)([calibrated, input_other])
    lstm1 = tf.keras.layers.LSTM(layer1_nodes, return_sequences=True)(input_all)
    lstm2 = tf.keras.layers.LSTM(layer2_nodes, return_sequences=True)(lstm1)
    flatten = tf.keras.layers.Flatten()(lstm2)
    # NOTE(review): Dense has no activation by default, so dense1 -> dense2
    # composes to a single affine map. Consider activation="relu".
    dense1 = tf.keras.layers.Dense(dense1_nodes)(flatten)
    dense2 = tf.keras.layers.Dense(dense2_nodes)(dense1)
    output = tf.keras.layers.Dense(num_classes, activation="softmax")(dense2)

    model = Model(inputs=[input_lattice, input_other], outputs=output)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        # The softmax head emits probabilities, not logits.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )
    return model

However, I am not sure how to do this for a multi-class classification model. Any suggestions on how to approach it?