LSTM in tensorflow.js game - what am I doing wrong?

Hi everyone,

I’m very new to neural networks and just trying to figure out the right approach for my project and why my current code is not working as planned. I apologize in advance for my stupid questions.

I am trying to use tensorflow.js in a game, where I have 18 inputs, such as the x- and y-positions of the enemies, and 3 outputs: the cosine and sine of an angle between -180 and 180 (as suggested here: circular statistics - Encoding Angle Data for Neural Network - Cross Validated) used for directions and an output that controls moving or not moving (-1: do not move, 1: move).

Because the cos and sin results are between -1 and 1, I am using -1 and 1 for the third output as well and tanh as the activation function. I pass in 36 inputs (previous and current frame) to predict which direction to move in during the current frame and whether to move at all. The net then returns 6 outputs (3 for each frame), so I take the last 3 and put them into the player controller. I use returnSequences: true for the last LSTM layer because I cannot seem to figure out how to make it output only the last 3 outputs without an error.

When I train the model using game data and then try to use it in-game, the outputs all center around certain values and barely vary, which leads to the player never moving at all or always moving in one direction. EDIT: This appears to be an issue of the amount of training/data. I’m beginning to see results with more training.

However, I still want to ask: Is my current method valid? What should I try to improve my approach? Training on my laptop takes very long, so I would like to improve training performance.
I also cannot find the memory leak - one tensor remains at the end. Can you find it? I think it is in “compile,” but using tf.tidy() or startScope/endScope did not help.

See my code below (just a test environment, not the entire game):

</div>
    <div>
        <span>Output after training: </span>
    </div>
    <div>
        <span id = "Output">0</span>
    </div>
</div>

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>

<script>
    // Sequence model: three stacked 18-unit LSTM layers followed by a 3-unit dense head.
    // Every LSTM keeps returnSequences enabled, so the dense layer is applied to each
    // timestep and the model emits 3 values per input frame.
    let net = tf.sequential();
    console.log("number of tensors in memory before adding layers:", tf.memory().numTensors);
    for (let layerIndex = 0; layerIndex < 3; layerIndex++) {
        const lstmConfig = {units: 18, activation: "tanh", returnSequences: true};
        if (layerIndex === 0) {
            // Only the first layer declares the input: any number of timesteps, 18 features each.
            lstmConfig.inputShape = [null, 18];
        }
        net.add(tf.layers.lstm(lstmConfig));
    }
    // tanh keeps all three outputs inside [-1, 1].
    net.add(tf.layers.dense({units: 3, activation: "tanh"}));
    console.log("number of tensors in memory after adding layers:", tf.memory().numTensors);
    // Mean squared error loss, plain SGD with a learning rate of 0.3.
    net.compile({loss: "meanSquaredError", optimizer: tf.train.sgd(0.3)});
    console.log("number of tensors in memory after compiling model:", tf.memory().numTensors);

    
    //Array of 360 items; one timestep is supposed to consist of 18 items, and the total length of the array should be variable
    let totalInputData = [0.5,0.5,1.0265625847740325,0.4968331195707686,-0.027321307508046434,0.8432027740193291,-0.026595301826361762,0.43737337534322807,0.86542210185598,-0.02739700386481555,-0.027123958327189335,0.21580715437649517,1.026609147819722,0.5748529301190116,1.026614258178386,0.5789125974595326,1.02669943068863,0.3698218427429599,0.5,0.5,1.0218751695480652,0.49686131087028745,-0.02339261501609287,0.8406458165394973,-0.021940603652723524,0.43792694662348824,0.8627524432575345,-0.0235440077296311,-0.02299791665437867,0.21803166249386177,1.0219682956394436,0.5741932731696758,1.0219785163567718,0.5782179363814965,1.02214886137726,0.3709465539025108,0.5,0.5,1.0171877543220977,0.4968895021698064,-0.019463922524139303,0.8380888590596655,-0.01728590547908529,0.4384805179037484,0.860082784659089,-0.01969101159444665,-0.018871874981568004,0.22025617061122835,1.0173274434591655,0.5735336162203402,1.0173427745351575,0.5775232753034605,1.01759829206589,0.3720712650620616,0.5,0.5,1.0125003390961305,0.49691769346932524,-0.01553523003218574,0.8355319015798337,-0.012631207305447053,0.43903408918400866,0.8574131260606436,-0.0158380154592622,-0.014745833308757339,0.22248067872859495,1.0126865912788872,0.5728739592710044,1.0127070327135435,0.5768286142254244,1.01304772275452,0.3731959762216125,0.5,0.5,1.007812923870163,0.49694588476884416,-0.011606537540232175,0.8329749441000018,-0.007976509131808817,0.43958766046426884,0.8547434674621981,-0.01198501932407775,-0.010619791635946672,0.22470518684596152,1.008045739098609,0.5722143023216686,1.0080712908919294,0.5761339531473884,1.00849715344315,0.3743206873811633,0.5,0.5,1.0031255086441955,0.4969740760683631,-0.007677845048278611,0.83041798662017,-0.0033218109581705816,0.440141231744529,0.8520738088637525,-0.0081320231888933,-0.006493749963136006,0.2269296949633281,1.003404886918331,0.5715546453723329,1.0034355490703153,0.5754392920693523,1.0039465841317798,0.3754453985407141,0.5,0.5,0.9984380934182282,0.49700226736788194,-0.0
03749152556325047,0.8278610291403382,0.0013328872154676542,0.4406948030247892,0.849404150265307,-0.00427902705370885,-0.002367708290325339,0.2291542030806947,0.9987640347380526,0.5708949884229971,0.9987998072487011,0.5747446309913162,0.99939601482041,0.376570109700265,0.5,0.5,0.9937506781922607,0.49703045866740087,0.00017953993562851712,0.8253040716605063,0.00598758538910589,0.4412483743050494,0.8467344916668615,-0.00042603091852440046,0.0017583333824853276,0.23137871119806128,0.9941231825577745,0.5702353314736615,0.994164065427087,0.5740499699132802,0.9948454455090399,0.3776948208598158,0.5,0.5,0.9890632629662933,0.49705864996691973,0.004108232427582082,0.8227471141806746,0.010642283562744126,0.4418019455853096,0.8440648330684161,0.0034269652166600494,0.005884375055295994,0.23360321931542788,0.9894823303774963,0.5695756745243257,0.9895283236054728,0.5733553088352441,0.9902948761976699,0.3788195320193667,0.5,0.5,0.9843758477403259,0.49708684126643865,0.008036924919535646,0.8201901567008427,0.015296981736382362,0.4423555168655698,0.8413951744699706,0.0072799613518444994,0.010010416728106661,0.23582772743279445,0.9848414781972181,0.5689160175749899,0.9848925817838587,0.572660647757208,0.9857443068862999,0.3799442431789175,0.5,0.5,0.9796884325143586,0.49711503256595757,0.01196561741148921,0.8176331992210109,0.019951679910020597,0.44290908814582997,0.838725515871525,0.011132957487028949,0.014136458400917328,0.23805223555016103,0.9802006260169399,0.5682563606256542,0.9802568399622447,0.571965986679172,0.9811937375749299,0.3810689543384683,0.5,0.5,0.9750010172883912,0.49714322386547644,0.015894309903442774,0.8150762417411791,0.024606378083658835,0.44346265942609014,0.8360558572730795,0.014985953622213399,0.018262500073727993,0.24027674366752763,0.9755597738366617,0.5675967036763184,0.9756210981406305,0.5712713256011359,0.9766431682635599,0.3821936654980192,0.5,0.5,0.9703136020624237,0.49717141516499536,0.01982300239539634,0.8125192842613472,0.02926107625729707,0.444016230
7063503,0.833386198674634,0.01883894975739785,0.02238854174653866,0.2425012517848942,0.9709189216563836,0.5669370467269828,0.9709853563190164,0.5705766645230999,0.9720925989521898,0.38331837665757,0.5,0.5,0.9656261868364563,0.4971996064645142,0.023751694887349902,0.8099623267815155,0.03391577443093531,0.44456980198661056,0.8307165400761886,0.022691945892582298,0.026514583419349324,0.2447257599022608,0.9662780694761054,0.566277389777647,0.9663496144974022,0.5698820034450638,0.9675420296408198,0.3844430878171209,0.5,0.5,0.9609387716104889,0.49722779776403314,0.027680387379303468,0.8074053693016836,0.038570472604573545,0.44512337326687074,0.8280468814777431,0.02654494202776675,0.03064062509215999,0.24695026801962738,0.9616372172958272,0.5656177328283112,0.9617138726757881,0.5691873423670277,0.9629914603294498,0.3855677989766717,0.5,0.5,0.9562513563845215,0.49725598906355206,0.03160907987125703,0.8048484118218517,0.04322517077821179,0.4456769445471309,0.8253772228792975,0.0303979381629512,0.03476666676497066,0.24917477613699396,0.956996365115549,0.5649580758789755,0.957078130854174,0.5684926812889917,0.9584408910180798,0.3866925101362225,0.5,0.5,0.9515639411585541,0.4972841803630709,0.035537772363210596,0.80229145434202,0.04787986895185002,0.4462305158273911,0.822707564280852,0.03425093429813565,0.03889270843778132,0.25139928425436053,0.9523555129352708,0.5642984189296397,0.9524423890325598,0.5677980202109556,0.9538903217067098,0.3878172212957734,0.5,0.5,0.9468765259325866,0.49731237166258985,0.03946646485516416,0.7997344968621881,0.05253456712548826,0.4467840871076513,0.8200379056824065,0.038103930433320096,0.04301875011059199,0.25362379237172716,0.9477146607549927,0.563638761980304,0.9478066472109458,0.5671033591329195,0.9493397523953397,0.3889419324553242,0.5,0.5,0.9421891107066193,0.4973405629621087,0.04339515734711773,0.7971775393823564,0.057189265299126504,0.4473376583879115,0.817368247083961,0.04195692656850455,0.047144791783402654,0.25584830048909374,0.943073808
5747145,0.5629791050309683,0.9431709053893316,0.5664086980548835,0.9447891830839698,0.3900666436148751,0.5,0.5,0.9375016954806519,0.49736875426162763,0.0473238498390713,0.7946205819025245,0.06184396347276474,0.4478912296681717,0.8146985884855156,0.045809922703689,0.05127083345621332,0.2580728086064603,0.9384329563944362,0.5623194480816325,0.9385351635677175,0.5657140369768474,0.9402386137725998,0.3911913547744259];
    
    console.log("length of input data", totalInputData.length);
    // Flat array of 60 target values: 20 timesteps of 3 items each, every timestep
    // being the same triple [1, 0, -1]. The total length is meant to be variable.
    let totalOutputData = Array.from({length: 20}, () => [1, 0, -1]).flat();
    // Items 3..5 are the targets of the second timestep.
    console.log("expected output:", totalOutputData.slice(3, 6));
    console.log("length of output data:", totalOutputData.length);
    
    /**
     * Trains the global `net` on one sequence of frames.
     *
     * The flat arrays are reshaped into a single sample: `inputs` becomes
     * [1, inputs.length / 18, 18] and `outputs` becomes [1, outputs.length / 3, 3].
     *
     * @param {number[]} inputs - Flat input values, 18 per timestep.
     * @param {number[]} outputs - Flat target values, 3 per timestep.
     * @returns {Promise<void>} Resolves once training has finished and the
     *     training tensors have been released.
     * @throws Whatever `tf.tensor3d` or `net.fit` throws (e.g. on a shape mismatch).
     */
    async function trainNet(inputs, outputs) {
        console.log("training input data:", inputs, "training output data:", outputs);
        console.time("training completed in");
        console.log("number of tensors in memory before training:", tf.memory().numTensors);
        let xs = null;
        let ys = null;
        try {
            xs = tf.tensor3d(inputs, [1, inputs.length / 18, 18]);
            ys = tf.tensor3d(outputs, [1, outputs.length / 3, 3]);
            const res = await net.fit(xs, ys, {
                // NOTE(review): the tensors above hold a single sample (first dimension is 1),
                // so this batch size looks like it has no practical effect — confirm.
                batchSize: 180,
                epochs: 10
            });
            console.log(res);
            console.log("loss after training:", res.history.loss);
        } finally {
            // Release the training tensors even when tensor creation or fit() throws,
            // so a failed run does not leak them.
            if (xs !== null) {
                xs.dispose();
            }
            if (ys !== null) {
                ys.dispose();
            }
            console.timeEnd("training completed in");
        }
        console.log("number of tensors in memory after training:", tf.memory().numTensors);
    }

    /**
     * Runs the global `net` on two sample frames, shows the prediction for the
     * last frame in the "Output" element, then disposes the model and its optimizer.
     *
     * The model returns one triple per timestep; only the final triple (the
     * current frame) is displayed.
     */
    function predictOutput() {
        const predictionData = [0.5,0.5,1.0265625847740325,0.4968331195707686,-0.027321307508046434,0.8432027740193291,-0.026595301826361762,0.43737337534322807,0.86542210185598,-0.02739700386481555,-0.027123958327189335,0.21580715437649517,1.026609147819722,0.5748529301190116,1.026614258178386,0.5789125974595326,1.02669943068863,0.3698218427429599,0.5,0.5,1.0218751695480652,0.49686131087028745,-0.02339261501609287,0.8406458165394973,-0.021940603652723524,0.43792694662348824,0.8627524432575345,-0.0235440077296311,-0.02299791665437867,0.21803166249386177,1.0219682956394436,0.5741932731696758,1.0219785163567718,0.5782179363814965,1.02214886137726,0.3709465539025108];
        console.log("length of prediction data:", predictionData.length);

        let netOutput;
        // tf.tidy disposes every tensor created inside the callback, including on an
        // exception, so no manual startScope/endScope or dispose() calls are needed.
        tf.tidy(() => {
            //Predicting the output based on a 3d tensor with 2 timesteps (predictionData corresponds to first 36 items in totalInputData)
            const predictionTensor = tf.tensor3d(predictionData, [1, 2, 18]);
            // dataSync() copies the values into a typed array that outlives the tidy scope.
            netOutput = net.predict(predictionTensor).dataSync();
        });

        //Displaying the last 3 outputs (outputs for the current frame) in the output sequence
        // textContent is used because the value is plain text, not markup.
        document.getElementById("Output").textContent = Array.from(netOutput.slice(-3)).join(",");

        // The optimizer was created by the caller (tf.train.sgd(...)) and handed to compile().
        // net.dispose() only frees optimizers the model created itself, so the learning-rate
        // tensor held by a user-supplied optimizer would otherwise stay in memory — this was
        // the single leftover tensor. Grab the reference first, then dispose both.
        const optimizer = net.optimizer;
        net.dispose();
        if (optimizer != null) {
            optimizer.dispose();
        }
        console.log("number of tensors in memory after predicting and disposing net:", tf.memory().numTensors);
    }

    // Train on the recorded data, then run a sample prediction. The catch handler
    // surfaces training/prediction failures instead of leaving an unhandled rejection.
    trainNet(totalInputData, totalOutputData)
        .then(predictOutput)
        .catch((err) => {
            console.error("training or prediction failed:", err);
        });
    //Intended output: [1,0,-1]
    //Example output at 1000 epochs: [0.9696730375289917,-0.0013935886090621352,-0.9711311459541321]
</script>