How to use sample weights with 3D medical data without model.fit(x=tf.data.Dataset) raising a "Can not squeeze dim" error

env:

python = 3.8.12
tensorflow = 2.6.0
keras = 2.6.0

I am trying to train on highly unbalanced data, so I tried to pass sample weights to model.fit(), but I always get the same error:

ValueError: Can not squeeze dim[4], expected a dimension of 1, got 4 for '{{node categorical_crossentropy/weighted_loss/Squeeze}} = Squeeze[T=DT_FLOAT, squeeze_dims=[-1]](Cast)' with input shapes: [?,48,48,80,4].

These are the shapes of the data. The y arrays were converted with tf.keras.utils.to_categorical, with num_classes = 4:

x_train (54, 48, 48, 80)
y_train (54, 48, 48, 80, 4)
x_test (18, 48, 48, 80)
y_test (18, 48, 48, 80, 4)
x_val (18, 48, 48, 80)
y_val (18, 48, 48, 80, 4)

The architecture is a U-Net:

inputs = Input((number_of_layers, height, width, 1))
c1 = Conv3D(filters=16, kernel_size=3, activation='relu', kernel_initializer='he_normal', padding='same')(inputs)
c1 = Dropout(0.1)(c1)
c1 = Conv3D(16, kernel_size=3, activation='relu', kernel_initializer='he_normal', padding='same')(c1)
p1 = MaxPooling3D(pool_size=2)(c1)

outputs = Conv3D(num_classes, kernel_size=1, activation='softmax')(u9)
model = Model(inputs=[inputs], outputs=[outputs])

The compile step looks like this:

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'], sample_weight_mode="temporal")

NOTE: I’m not using metrics=['accuracy'] for evaluation; I’m using an IoU metric.

The problem appears when I compute the weights like this:

from sklearn.utils.class_weight import compute_sample_weight
weights = compute_sample_weight(class_weight='balanced', y=y_train.flatten())
weights = weights.reshape(y_train.shape)
weights.shape # => (54, 48, 48, 80, 4) (same as y_train)
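
One possible way to get weights that line up with a per-voxel loss is sketched below, using NumPy only. Note that flattening the one-hot array makes the "classes" seen by compute_sample_weight the values 0 and 1 rather than the four labels. The sketch instead assumes the labels are one-hot along the last axis, recovers the integer class of each voxel with argmax, and applies the usual "balanced" heuristic (number of voxels divided by number of classes times the per-class count). The helper name per_voxel_balanced_weights and the small random volume are hypothetical and only for illustration.

```python
import numpy as np

def per_voxel_balanced_weights(y_onehot):
    """Return one weight per voxel, with shape y_onehot.shape[:-1]."""
    num_classes = y_onehot.shape[-1]
    labels = np.argmax(y_onehot, axis=-1)            # integer class per voxel
    counts = np.bincount(labels.ravel(), minlength=num_classes)
    # "balanced" heuristic: n_voxels / (n_classes * count); absent classes keep weight 0
    class_w = np.zeros(num_classes, dtype=np.float32)
    present = counts > 0
    class_w[present] = labels.size / (num_classes * counts[present])
    return class_w[labels]                           # look up the weight of each voxel's class

# Hypothetical toy data: 2 volumes of 4x4x5 voxels, 4 classes, class 0 dominant
rng = np.random.default_rng(0)
labels = rng.choice(4, size=(2, 4, 4, 5), p=[0.85, 0.05, 0.05, 0.05])
y_onehot = np.eye(4, dtype=np.float32)[labels]       # shape (2, 4, 4, 5, 4)
weights = per_voxel_balanced_weights(y_onehot)       # shape (2, 4, 4, 5)
print(y_onehot.shape, weights.shape)
```

Weights shaped like this have no class axis, so they should match a loss of shape (batch, 48, 48, 80) and could take the place of the 5-D weights array in the dataset.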

Up to this point everything works without errors. I then added the weights to the following dataset:

tf_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train, weights)).batch(4)

After that, I ran model.fit:

model.fit(x=tf_ds, verbose=1, epochs=5, validation_data=(x_val, y_val))

I got the following error:

ValueError: Can not squeeze dim[4], expected a dimension of 1, got 4 for '{{node categorical_crossentropy/weighted_loss/Squeeze}} = Squeeze[T=DT_FLOAT, squeeze_dims=[-1]](Cast)' with input shapes: [?,48,48,80,4].
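
The shapes in this message can be reproduced outside Keras. The sketch below is a NumPy illustration and not the Keras internals: categorical cross-entropy sums over the class axis, so the loss has one value per voxel, and weights that still carry a class axis of size 4 cannot be squeezed down to match it. The small shapes and variable names are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, d, h, w, num_classes = 2, 3, 3, 5, 4    # small stand-in for (?, 48, 48, 80, 4)

labels = rng.integers(0, num_classes, size=(batch, d, h, w))
y_true = np.eye(num_classes)[labels]           # one-hot, shape (2, 3, 3, 5, 4)
logits = rng.normal(size=y_true.shape)
y_pred = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax

# Cross-entropy reduces the class axis: one loss value per voxel
loss = -(y_true * np.log(y_pred)).sum(axis=-1)
print(loss.shape)                              # rank 4, no class axis

bad_weights = np.ones(y_true.shape)            # same shape as y_true, last dim is 4
try:
    np.squeeze(bad_weights, axis=-1)           # only succeeds when the last dim is 1
    squeeze_failed = False
except ValueError:
    squeeze_failed = True

good_weights = np.ones(loss.shape + (1,))      # trailing axis of size 1
weighted = loss * np.squeeze(good_weights, axis=-1)
```

In this sketch only the weights with a trailing axis of size 1 (or none at all) can be multiplied element-wise with the per-voxel loss.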

Any ideas on how to solve this?

My understanding: this model produces a 4-class one-hot output for each voxel in the input, and I can see how that would be an unbalanced set. Usually, though, “unbalanced set” refers to the sample level rather than to the low-level data. That is, it means “too many dogs” when classifying “dogs vs. cats”, not “this pixel is fur or not fur”. You are trying to create a weight set for “fur or not fur” within a picture of a cat or dog, with a separate weight set for each pixel.

The implementation of class weights is biased towards the “dogs vs. cats” use case. It is not intended for the “fur or not fur” case.

If per-voxel classification really is what you are after, you can split the output layer into 48x48x80 separate one-hot layers and apply a unique weight set at each output. This may take a little longer for the model to compile, but it should run at exactly the same speed.
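
As a point of comparison with splitting the output layer, per-class weights can also be folded into the loss arithmetic itself, which gives a per-voxel weighting without extra outputs. The sketch below shows that arithmetic in NumPy only. The function weighted_categorical_crossentropy and the class weight values are hypothetical and are not something either post above uses.

```python
import numpy as np

def weighted_categorical_crossentropy(y_true, y_pred, class_weights, eps=1e-7):
    """Mean cross-entropy where each voxel is scaled by the weight of its true class."""
    y_pred = np.clip(y_pred, eps, 1.0)                   # avoid log(0)
    per_voxel = -(y_true * np.log(y_pred)).sum(axis=-1)  # one loss value per voxel
    voxel_w = (y_true * class_weights).sum(axis=-1)      # weight of each voxel's true class
    return float((per_voxel * voxel_w).mean())

rng = np.random.default_rng(1)
labels = rng.integers(0, 4, size=(2, 3, 3, 5))
y_true = np.eye(4)[labels]                               # one-hot, shape (2, 3, 3, 5, 4)
y_pred = np.full(y_true.shape, 0.25)                     # uniform prediction over 4 classes

class_weights = np.array([0.3, 2.0, 2.0, 2.0])           # assumed values, for illustration
unweighted = weighted_categorical_crossentropy(y_true, y_pred, np.ones(4))
weighted = weighted_categorical_crossentropy(y_true, y_pred, class_weights)
```

With all class weights set to 1 this reduces to the ordinary mean cross-entropy, so the weighting only changes how much each voxel contributes.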