Different inputs in the same batch generate the same output

TensorFlow v2.9. I am using on-device training, so the SavedModel cannot leverage the high-level APIs (e.g. predict or fit); prediction is done with model(x), as explained in On-Device Training with TensorFlow Lite.

I don’t know why, but I get the same result for different inputs in a batch.

E.g. the batch size is 3: model(x) accepts input of shape (3, 4, 15, 15), and one of its outputs has shape (3, 1, 225).

As listed below, all three (1, 225) vectors in the output tensor are exactly the same, even though their inputs in the batch differ.

[
	[
		[-20.500122, -20.500196, -16.388021, -20.500189, -13.888604, -20.500208, -20.500103, -13.725816, -16.14115, -15.523373, -16.094854, -15.536175, -13.494872, -20.500164, -16.729692, -17.314562, -9.923043, -20.500137, -13.227316, -19.462494, -8.832517, -11.005514, -16.657751, -20.500229, -19.104895, -17.969429, -16.826006, -18.479736, -11.35681, -20.50018, -17.686893, -15.8137665, -20.500158, -20.498934, -11.30343, -12.114782, -6.9864135, -16.129002, -11.758956, -13.793568, -10.100338, -18.394066, -7.8771715, -18.867481, -13.54011, -20.500141, -18.142273, -13.827344, -12.14585, -8.751808, -7.360826, -7.8197165, -8.190978, -7.9918194, -7.1475286, -10.866553, -13.463445, -12.561472, -17.644833, -20.499897, -15.04738, -15.1495285, -15.757288, -10.316235, -6.4681287, -6.771983, -6.2083254, -5.169312, -5.9851274, -7.3863406, -5.7047515, -11.461843, -19.462492, -20.499823, -16.014748, -19.572166, -10.054104, -9.654353, -6.9895654, -6.523039, -3.4712281, -4.010914, -3.058044, -5.203539, -4.562346, -7.3472414, -8.2306795, -14.15948, -16.442978, -15.1097, -20.499994, -16.006512, -13.285485, -9.599341, -5.576161, -5.10128, -2.1091957, -2.6103199, -2.3030841, -4.3452697, -5.1566353, -6.7773423, -13.5079155, -18.91643, -20.49996, -20.50012, -20.500032, -15.034921, -7.0785294, -6.62519, -2.6741242, -3.3764887, -3.2719333, -3.4223785, -3.1113718, -6.607987, -6.7852387, -9.567825, -17.231964, -18.361439, -15.199417, -20.500113, -8.907006, -8.894981, -4.4610567, -5.3974047, -3.1986039, -3.308056, -2.5260184, -4.416704, -5.5637026, -8.839353, -7.404949, -18.09958, -20.499996, -20.500063, -12.94954, -17.1081, -7.8807735, -6.0368576, -4.0000243, -4.983799, -3.7624922, -3.9401622, -5.351621, -7.3347793, -6.7273192, -16.521574, -10.310918, -18.213472, -18.239689, -20.49987, -13.403644, -10.768933, -6.169673, -6.226465, -4.851883, -3.5755277, -5.7955694, -7.59566, -6.5219584, -15.287647, -9.992104, -20.49974, -11.737182, -20.500032, -13.7056465, -11.700055, -11.151376, -12.240701, 
-6.9801717, -9.907572, -9.89772, -7.7714005, -7.599248, -14.2966175, -10.805019, -14.946489, -15.138906, -20.49991, -16.84454, -20.500303, -15.745817, -11.974067, -14.362624, -13.677492, -6.8857694, -10.488706, -9.6858, -15.690493, -13.776093, -17.350763, -13.82417, -20.500122, -16.799477, -11.256063, -16.112524, -20.500021, -16.107948, -11.349038, -12.018146, -20.500145, -15.021783, -20.500141, -14.088732, -19.462494, -16.841585, -17.49845, -15.664743, -18.375904, -20.500162, -17.897068, -20.50004, -13.704247, -15.333616, -20.500124, -14.740182, -12.495611, -20.500069, -20.50013, -17.074047, -13.579008, -16.136011, -20.500244, -11.993184]
	],
	[
		[-20.500122, -20.500196, -16.388021, -20.500189, -13.888604, -20.500208, -20.500103, -13.725816, -16.14115, -15.523373, -16.094854, -15.536175, -13.494872, -20.500164, -16.729692, -17.314562, -9.923043, -20.500137, -13.227316, -19.462494, -8.832517, -11.005514, -16.657751, -20.500229, -19.104895, -17.969429, -16.826006, -18.479736, -11.35681, -20.50018, -17.686893, -15.8137665, -20.500158, -20.498934, -11.30343, -12.114782, -6.9864135, -16.129002, -11.758956, -13.793568, -10.100338, -18.394066, -7.8771715, -18.867481, -13.54011, -20.500141, -18.142273, -13.827344, -12.14585, -8.751808, -7.360826, -7.8197165, -8.190978, -7.9918194, -7.1475286, -10.866553, -13.463445, -12.561472, -17.644833, -20.499897, -15.04738, -15.1495285, -15.757288, -10.316235, -6.4681287, -6.771983, -6.2083254, -5.169312, -5.9851274, -7.3863406, -5.7047515, -11.461843, -19.462492, -20.499823, -16.014748, -19.572166, -10.054104, -9.654353, -6.9895654, -6.523039, -3.4712281, -4.010914, -3.058044, -5.203539, -4.562346, -7.3472414, -8.2306795, -14.15948, -16.442978, -15.1097, -20.499994, -16.006512, -13.285485, -9.599341, -5.576161, -5.10128, -2.1091957, -2.6103199, -2.3030841, -4.3452697, -5.1566353, -6.7773423, -13.5079155, -18.91643, -20.49996, -20.50012, -20.500032, -15.034921, -7.0785294, -6.62519, -2.6741242, -3.3764887, -3.2719333, -3.4223785, -3.1113718, -6.607987, -6.7852387, -9.567825, -17.231964, -18.361439, -15.199417, -20.500113, -8.907006, -8.894981, -4.4610567, -5.3974047, -3.1986039, -3.308056, -2.5260184, -4.416704, -5.5637026, -8.839353, -7.404949, -18.09958, -20.499996, -20.500063, -12.94954, -17.1081, -7.8807735, -6.0368576, -4.0000243, -4.983799, -3.7624922, -3.9401622, -5.351621, -7.3347793, -6.7273192, -16.521574, -10.310918, -18.213472, -18.239689, -20.49987, -13.403644, -10.768933, -6.169673, -6.226465, -4.851883, -3.5755277, -5.7955694, -7.59566, -6.5219584, -15.287647, -9.992104, -20.49974, -11.737182, -20.500032, -13.7056465, -11.700055, -11.151376, -12.240701, 
-6.9801717, -9.907572, -9.89772, -7.7714005, -7.599248, -14.2966175, -10.805019, -14.946489, -15.138906, -20.49991, -16.84454, -20.500303, -15.745817, -11.974067, -14.362624, -13.677492, -6.8857694, -10.488706, -9.6858, -15.690493, -13.776093, -17.350763, -13.82417, -20.500122, -16.799477, -11.256063, -16.112524, -20.500021, -16.107948, -11.349038, -12.018146, -20.500145, -15.021783, -20.500141, -14.088732, -19.462494, -16.841585, -17.49845, -15.664743, -18.375904, -20.500162, -17.897068, -20.50004, -13.704247, -15.333616, -20.500124, -14.740182, -12.495611, -20.500069, -20.50013, -17.074047, -13.579008, -16.136011, -20.500244, -11.993184]
	],
	[
		[-20.500122, -20.500196, -16.388021, -20.500189, -13.888604, -20.500208, -20.500103, -13.725816, -16.14115, -15.523373, -16.094854, -15.536175, -13.494872, -20.500164, -16.729692, -17.314562, -9.923043, -20.500137, -13.227316, -19.462494, -8.832517, -11.005514, -16.657751, -20.500229, -19.104895, -17.969429, -16.826006, -18.479736, -11.35681, -20.50018, -17.686893, -15.8137665, -20.500158, -20.498934, -11.30343, -12.114782, -6.9864135, -16.129002, -11.758956, -13.793568, -10.100338, -18.394066, -7.8771715, -18.867481, -13.54011, -20.500141, -18.142273, -13.827344, -12.14585, -8.751808, -7.360826, -7.8197165, -8.190978, -7.9918194, -7.1475286, -10.866553, -13.463445, -12.561472, -17.644833, -20.499897, -15.04738, -15.1495285, -15.757288, -10.316235, -6.4681287, -6.771983, -6.2083254, -5.169312, -5.9851274, -7.3863406, -5.7047515, -11.461843, -19.462492, -20.499823, -16.014748, -19.572166, -10.054104, -9.654353, -6.9895654, -6.523039, -3.4712281, -4.010914, -3.058044, -5.203539, -4.562346, -7.3472414, -8.2306795, -14.15948, -16.442978, -15.1097, -20.499994, -16.006512, -13.285485, -9.599341, -5.576161, -5.10128, -2.1091957, -2.6103199, -2.3030841, -4.3452697, -5.1566353, -6.7773423, -13.5079155, -18.91643, -20.49996, -20.50012, -20.500032, -15.034921, -7.0785294, -6.62519, -2.6741242, -3.3764887, -3.2719333, -3.4223785, -3.1113718, -6.607987, -6.7852387, -9.567825, -17.231964, -18.361439, -15.199417, -20.500113, -8.907006, -8.894981, -4.4610567, -5.3974047, -3.1986039, -3.308056, -2.5260184, -4.416704, -5.5637026, -8.839353, -7.404949, -18.09958, -20.499996, -20.500063, -12.94954, -17.1081, -7.8807735, -6.0368576, -4.0000243, -4.983799, -3.7624922, -3.9401622, -5.351621, -7.3347793, -6.7273192, -16.521574, -10.310918, -18.213472, -18.239689, -20.49987, -13.403644, -10.768933, -6.169673, -6.226465, -4.851883, -3.5755277, -5.7955694, -7.59566, -6.5219584, -15.287647, -9.992104, -20.49974, -11.737182, -20.500032, -13.7056465, -11.700055, -11.151376, -12.240701, 
-6.9801717, -9.907572, -9.89772, -7.7714005, -7.599248, -14.2966175, -10.805019, -14.946489, -15.138906, -20.49991, -16.84454, -20.500303, -15.745817, -11.974067, -14.362624, -13.677492, -6.8857694, -10.488706, -9.6858, -15.690493, -13.776093, -17.350763, -13.82417, -20.500122, -16.799477, -11.256063, -16.112524, -20.500021, -16.107948, -11.349038, -12.018146, -20.500145, -15.021783, -20.500141, -14.088732, -19.462494, -16.841585, -17.49845, -15.664743, -18.375904, -20.500162, -17.897068, -20.50004, -13.704247, -15.333616, -20.500124, -14.740182, -12.495611, -20.500069, -20.50013, -17.074047, -13.579008, -16.136011, -20.500244, -11.993184]
	]
]
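A quick sanity check (a minimal NumPy sketch; `batch_rows_identical` is a hypothetical helper name, not part of the code below) confirms that every row of the output batch equals row 0 exactly, rather than merely being close:

```python
import numpy as np

def batch_rows_identical(output):
    """Return True if every row of a (batch, ...) array equals row 0 exactly."""
    output = np.asarray(output)
    return all(np.array_equal(output[0], row) for row in output[1:])

# Degenerate output, like the one above: three identical rows.
bad = np.tile(np.array([-20.500122, -16.388021, -13.888604]), (3, 1, 1))
# Healthy output: rows differ across the batch.
good = np.arange(9, dtype=np.float32).reshape(3, 1, 3)
```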

The (4, 15, 15) sub-tensors in the input tensor, however, are all different:

[[[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
   [0 0 0 0 0 1 1 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]

  [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 1 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 1 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 1 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 1 0 1 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]

  [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 1 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]

  [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]]


 [[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 1 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]

  [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 1 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 1 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]

  [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 1 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]

  [[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
   [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
   [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
   [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
   [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
   [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
   [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
   [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
   [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
   [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
   [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
   [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
   [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
   [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
   [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]]]


 [[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
   [0 0 0 0 0 0 1 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]

  [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 1 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]

  [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]

  [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
   [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]]]

What is the problem here? The full source code is below.


import tensorflow as tf
import tensorflow_probability as tfp


def create_model(board_width, board_height):

    class RenjuModel(tf.Module):
        def __init__(self):
            l2_penalty_beta = 1e-4

            # Define the tensorflow neural network
            # 1. Input:
            self.inputs = tf.keras.Input( shape=(4, board_height, board_width), dtype=tf.dtypes.float32, name="input")
            self.transposed_inputs = tf.keras.layers.Lambda( lambda x: tf.transpose(x, [0, 2, 3, 1]) )(self.inputs)

            # 2. Common Networks Layers
            self.conv1 = tf.keras.layers.Conv2D( name="conv1",
                filters=32,
                kernel_size=(3, 3),
                padding="same",
                data_format="channels_last",
                activation=tf.keras.activations.relu,
                kernel_regularizer=tf.keras.regularizers.L2(l2_penalty_beta)
                )(self.transposed_inputs)

            self.conv2 = tf.keras.layers.Conv2D( name="conv2", 
                filters=64, 
                kernel_size=(3, 3), 
                padding="same", 
                data_format="channels_last", 
                activation=tf.keras.activations.relu,
                kernel_regularizer=tf.keras.regularizers.L2(l2_penalty_beta)
                )(self.conv1)

            self.conv3 = tf.keras.layers.Conv2D( name="conv3",
                filters=128,
                kernel_size=(3, 3),
                padding="same",
                data_format="channels_last",
                activation=tf.keras.activations.relu,
                kernel_regularizer=tf.keras.regularizers.L2(l2_penalty_beta)
                )(self.conv2)

            # 3-1 Action Networks
            self.action_conv = tf.keras.layers.Conv2D( name="action_conv",
                filters=4,
                kernel_size=(1, 1),
                padding="same",
                data_format="channels_last",
                activation=tf.keras.activations.relu,
                kernel_regularizer=tf.keras.regularizers.L2(l2_penalty_beta)
                )(self.conv3)

            # flatten the tensor (note: a Reshape target of (-1, 4*H*W) produces
            # shape (batch, 1, 4*H*W), which is where the extra middle dimension
            # in the (3, 1, 225) output comes from)
            self.action_conv_flat = tf.keras.layers.Reshape( (-1, 4 * board_height * board_width), name="action_conv_flat"
            )(self.action_conv)

            # 3-2 Fully connected layer; the output is the log probability of moves
            # for each slot on the board
            self.action_fc = tf.keras.layers.Dense( board_height * board_width,
                activation=tf.nn.log_softmax,
                name="action_fc",
                kernel_regularizer=tf.keras.regularizers.L2(l2_penalty_beta)
                )(self.action_conv_flat)

            # 4 Evaluation Networks
            self.evaluation_conv = tf.keras.layers.Conv2D( name="evaluation_conv",
                filters=2,
                kernel_size=(1, 1),
                padding="same",
                data_format="channels_last",
                activation=tf.keras.activations.relu,
                kernel_regularizer=tf.keras.regularizers.L2(l2_penalty_beta)
                )(self.conv3)

            self.evaluation_conv_flat = tf.keras.layers.Reshape( (-1, 2 * board_height * board_width),
                name="evaluation_conv_flat" 
                )(self.evaluation_conv)

            self.evaluation_fc1 = tf.keras.layers.Dense( 64,
                activation=tf.keras.activations.relu,
                name="evaluation_fc1",
                kernel_regularizer=tf.keras.regularizers.L2(l2_penalty_beta)
                )(self.evaluation_conv_flat)

            self.evaluation_fc2 = tf.keras.layers.Dense( 1, 
                activation=tf.keras.activations.tanh,
                name="evaluation_fc2",
                kernel_regularizer=tf.keras.regularizers.L2(l2_penalty_beta)
                )(self.evaluation_fc1)

            self.model = tf.keras.Model(inputs=self.inputs, outputs=[self.action_fc, self.evaluation_fc2], name="renju_model")
            self.model.summary()
 
            self.lr = tf.Variable(0.002, trainable=False, dtype=tf.dtypes.float32)

            self.model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate = self.lr),
                    loss=[self.action_loss, tf.keras.losses.MeanSquaredError()],
                    metrics=['accuracy'])


        @tf.function(input_signature=[ tf.TensorSpec([None, 1, board_height * board_width], tf.float32),
            tf.TensorSpec([None, 1, board_height * board_width], tf.float32)
        ])
        def action_loss(self, labels, predictions):
            tf.print(labels, summarize=-1)
            tf.print(predictions, summarize=-1)
            # labels are probabilities; predictions are logits
            return tf.negative(tf.reduce_mean(
                        tf.reduce_sum(tf.multiply(labels, predictions), 2)))
           

        @tf.function(input_signature=[
            tf.TensorSpec([None, 4, board_height, board_width], tf.float32),
        ])
        def predict(self, state_batch):
            if tf.shape(state_batch)[0] > 1:
                tf.print(state_batch, summarize=-1)
            x = self.model(state_batch)
            if tf.shape(state_batch)[0] > 1:
                tf.print(x, summarize=-1)
            return x

        @tf.function(input_signature=[tf.TensorSpec(shape=[None, 4, board_height, board_width],  dtype=tf.float32), 
                                  tf.TensorSpec(shape=[None, 1, board_height * board_width],  dtype=tf.float32),
                                  tf.TensorSpec(shape=[],  dtype=tf.float32),
                                  tf.TensorSpec(shape=[1],  dtype=tf.float32) ])
        def train(self, state_batch, mcts_probs, winner_batch, lr):
            self.lr.assign(tf.gather(lr, 0))
            with tf.GradientTape() as tape:
                predictions = self.model(state_batch, training=True)  # Forward pass
                # the loss function is configured in `compile()`
                loss = self.model.compiled_loss([mcts_probs, winner_batch], predictions, regularization_losses=self.model.losses)
 
            gradients = tape.gradient(loss, self.model.trainable_variables)
            self.model.optimizer.apply_gradients(
                zip(gradients, self.model.trainable_variables))

            entropy = tf.negative(tf.reduce_mean(
                tf.reduce_sum(tf.exp(predictions[0][0]) * predictions[0][0], 1)))

            return (loss, entropy)

        

        @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
        def save(self, checkpoint_path):
            tensor_names = [weight.name for weight in self.model.weights]
            tensors_to_save = [weight.read_value() for weight in self.model.weights]
            tf.raw_ops.Save(
                filename=checkpoint_path, tensor_names=tensor_names,
                data=tensors_to_save, name='save')
            return checkpoint_path

        @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
        def restore(self, checkpoint_path):
            restored_tensors = {}
            for var in self.model.weights:
                restored = tf.raw_ops.Restore( file_pattern=checkpoint_path, tensor_name=var.name, dt=var.dtype, name='restore')
                var.assign(restored)
                restored_tensors[var.name] = restored
            return checkpoint_path

        @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
        def random_choose_with_dirichlet_noice(self, probs):
            concentration = 0.3*tf.ones(tf.size(probs))
            dist = tfp.distributions.Dirichlet(concentration)
            p = 0.75*probs + 0.25*dist.sample(1)[0]
            samples = tf.random.categorical(tf.math.log([p]), 1)
            return samples[0] # selected index


    return RenjuModel()


model = create_model(15, 15)

# Saving the model, explicitly adding the concrete functions as signatures
model.model.save('renju_15x15_model', 
        save_format='tf', 
        signatures={
            'predict': model.predict.get_concrete_function(), 
            'train' : model.train.get_concrete_function(), 
            'save' : model.save.get_concrete_function(),
            'restore' : model.restore.get_concrete_function(),
            'random_choose_with_dirichlet_noice' : model.random_choose_with_dirichlet_noice.get_concrete_function() 
        })

I also noticed that when the loss function is called, the predictions are the same even though the labels are different.

The following are the labels in action_loss, copied from the log:

[[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]

 [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2.97808528e-10 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]]

And here are the predictions in action_loss, copied from the log:

[[[-14.9128876 -18.2595577 -18.2102909 -14.6844273 -12.9892006 -18.2595615 -14.8576031 -18.1647701 -17.8895378 -18.200613 -16.1771622 -18.2007656 -17.6789341 -18.2302532 -18.2136555 -18.2191982 -15.3528461 -18.2595577 -18.1463146 -18.2443447 -13.4267159 -17.9375782 -13.1170254 -18.2595615 -15.4441738 -16.3065681 -15.7941322 -18.2313137 -14.2309837 -15.8361015 -15.5098228 -18.2041645 -16.0818138 -18.2312984 -15.1668968 -14.484169 -9.0487175 -18.2076473 -17.7939262 -14.9300575 -14.5430374 -18.1603451 -15.5655346 -15.8662224 -18.1585827 -18.2595577 -18.2274876 -17.9433651 -18.0791912 -17.2240448 -5.82286024 -12.4535313 -7.39751387 -12.1833401 -11.5815334 -17.9126835 -16.6792641 -14.0191441 -18.2223377 -18.2595539 -18.1939259 -18.1954784 -18.2034779 -11.0206509 -6.06101179 -5.52775908 -4.94801378 -6.46023893 -4.56483316 -4.82473326 -5.49228239 -13.8817501 -15.5314484 -14.8815937 -14.9936275 -18.2459621 -17.7217 -16.2399197 -7.76165915 -6.78857756 -5.38073587 -5.3190856 -4.74219465 -3.27214193 -3.43368292 -7.8375411 -7.32885885 -13.8216724 -13.2183819 -18.194891 -18.2595539 -18.2063 -9.05840492 -17.5760021 -6.32114935 -4.44361448 -2.77819204 -2.13587618 -2.46561384 -5.80532598 -5.21131754 -8.65453339 -12.7332077 -18.2367496 -18.2595539 -14.5108356 -18.2305813 -15.8598061 -9.3568306 -5.74266 -4.37345362 -3.24221468 -2.65785074 -3.87472391 -4.98595476 -4.87942362 -12.8089809 -17.5647907 -18.2183933 -14.1898327 -18.1962223 -18.2301655 -6.70136404 -8.67263603 -5.47746801 -3.46449232 -2.72609091 -3.02782202 -1.90738916 -3.56713152 -5.51725912 -10.3449841 -10.1139374 -18.2270069 -18.2595539 -18.2339573 -12.7653484 -18.1167927 -6.23910666 -6.35667 -4.42290449 -4.16379118 -5.01942873 -4.13271 -5.01809454 -6.85058355 -10.1382027 -15.6842995 -15.1318359 -18.2282581 -18.2285519 -18.2595539 -14.4078712 -17.89398 -7.81883955 -6.51721334 -5.01675558 -6.04885626 -4.90908194 -5.97763777 -6.7340064 -18.1974583 -17.7035789 -14.9518013 -13.6722279 -18.2449627 -18.164156 -18.0351849 
-16.3881531 -11.707468 -6.16785192 -13.9624691 -11.5368443 -13.3123302 -6.41238356 -18.1799755 -17.9010029 -17.7942257 -17.9693356 -18.2322178 -15.3313866 -18.2595615 -18.2033291 -14.2638874 -18.1814213 -18.1632252 -6.92246962 -17.8341541 -15.7046471 -18.2026615 -18.1663494 -13.7783928 -15.1686764 -18.2595577 -13.934206 -17.9775791 -18.2074642 -18.2595577 -18.0839062 -17.990942 -18.0677891 -18.2320728 -18.1935482 -18.2595577 -17.8355331 -18.1885815 -18.214716 -18.2209225 -18.2023487 -18.2300549 -13.7487602 -18.2248821 -18.2595577 -18.1641026 -18.1980801 -18.2595577 -18.1888523 -18.1060772 -18.2595577 -18.2595577 -18.2169056 -18.1599483 -18.2077236 -18.2595615 -18.0654659]]

 [[-14.9128876 -18.2595577 -18.2102909 -14.6844273 -12.9892006 -18.2595615 -14.8576031 -18.1647701 -17.8895378 -18.200613 -16.1771622 -18.2007656 -17.6789341 -18.2302532 -18.2136555 -18.2191982 -15.3528461 -18.2595577 -18.1463146 -18.2443447 -13.4267159 -17.9375782 -13.1170254 -18.2595615 -15.4441738 -16.3065681 -15.7941322 -18.2313137 -14.2309837 -15.8361015 -15.5098228 -18.2041645 -16.0818138 -18.2312984 -15.1668968 -14.484169 -9.0487175 -18.2076473 -17.7939262 -14.9300575 -14.5430374 -18.1603451 -15.5655346 -15.8662224 -18.1585827 -18.2595577 -18.2274876 -17.9433651 -18.0791912 -17.2240448 -5.82286024 -12.4535313 -7.39751387 -12.1833401 -11.5815334 -17.9126835 -16.6792641 -14.0191441 -18.2223377 -18.2595539 -18.1939259 -18.1954784 -18.2034779 -11.0206509 -6.06101179 -5.52775908 -4.94801378 -6.46023893 -4.56483316 -4.82473326 -5.49228239 -13.8817501 -15.5314484 -14.8815937 -14.9936275 -18.2459621 -17.7217 -16.2399197 -7.76165915 -6.78857756 -5.38073587 -5.3190856 -4.74219465 -3.27214193 -3.43368292 -7.8375411 -7.32885885 -13.8216724 -13.2183819 -18.194891 -18.2595539 -18.2063 -9.05840492 -17.5760021 -6.32114935 -4.44361448 -2.77819204 -2.13587618 -2.46561384 -5.80532598 -5.21131754 -8.65453339 -12.7332077 -18.2367496 -18.2595539 -14.5108356 -18.2305813 -15.8598061 -9.3568306 -5.74266 -4.37345362 -3.24221468 -2.65785074 -3.87472391 -4.98595476 -4.87942362 -12.8089809 -17.5647907 -18.2183933 -14.1898327 -18.1962223 -18.2301655 -6.70136404 -8.67263603 -5.47746801 -3.46449232 -2.72609091 -3.02782202 -1.90738916 -3.56713152 -5.51725912 -10.3449841 -10.1139374 -18.2270069 -18.2595539 -18.2339573 -12.7653484 -18.1167927 -6.23910666 -6.35667 -4.42290449 -4.16379118 -5.01942873 -4.13271 -5.01809454 -6.85058355 -10.1382027 -15.6842995 -15.1318359 -18.2282581 -18.2285519 -18.2595539 -14.4078712 -17.89398 -7.81883955 -6.51721334 -5.01675558 -6.04885626 -4.90908194 -5.97763777 -6.7340064 -18.1974583 -17.7035789 -14.9518013 -13.6722279 -18.2449627 -18.164156 -18.0351849 
-16.3881531 -11.707468 -6.16785192 -13.9624691 -11.5368443 -13.3123302 -6.41238356 -18.1799755 -17.9010029 -17.7942257 -17.9693356 -18.2322178 -15.3313866 -18.2595615 -18.2033291 -14.2638874 -18.1814213 -18.1632252 -6.92246962 -17.8341541 -15.7046471 -18.2026615 -18.1663494 -13.7783928 -15.1686764 -18.2595577 -13.934206 -17.9775791 -18.2074642 -18.2595577 -18.0839062 -17.990942 -18.0677891 -18.2320728 -18.1935482 -18.2595577 -17.8355331 -18.1885815 -18.214716 -18.2209225 -18.2023487 -18.2300549 -13.7487602 -18.2248821 -18.2595577 -18.1641026 -18.1980801 -18.2595577 -18.1888523 -18.1060772 -18.2595577 -18.2595577 -18.2169056 -18.1599483 -18.2077236 -18.2595615 -18.0654659]]]

Update: the reason has been found. It is because an invalid checkpoint was restored.
