GNN for tensorflow.js?

I might have missed it, but is GNN support going to be released for JavaScript any time soon?


This is a good question. As far as I know, there is no timeline for this in the immediate future, but maybe @pyu or @lina128 knows whether it is in the pipeline for 2022.

Thanks for the info. I’m still hoping to hear from those you tagged, but for my 2c, most of the useful use cases seem to rely on GNNs, since most structured data is a graph. Analysing images is great and all, but GNNs seem MUCH more useful than anything else, so it would be great if this could be prioritised. Also, .js is preferred over Python.

Totally with you there - JS FTW!

We are a younger team, so we are continually working to replicate the most needed features from TF Research in JS.

On a positive note, though, TensorFlow.js usage has grown exponentially over the last few years as JS developers realize the potential here. Roughly 70% of devs use JS as a primary language, so there is a huge potential future community in the making as ML reaches more folk.

Thanks again for your enthusiasm for TensorFlow.js! Things are slow right now, as it is New Year’s Eve tomorrow, so while we wait for a reply: if you find any other folk interested in this topic for JS, please direct them here so we can find out whether anyone would be interested in helping to implement it. In case the team does not have capacity to do this in the immediate future, it could become a community contribution to the product as part of the TensorFlow.js SIG (Special Interest Group).

In addition, I have started a feature request on GitHub, which you can also point people to, so we can monitor how much people want or need this feature and prioritize it in the future. Feel free to share it with others:

Feel free to add more details to this open feature request, for example which parts would be the most crucial to have first, if some sub-part is more important than the others.


I’m happy to help. I’m a pretty decent JS developer with a degree in computing science and 30 years of coding overall, and I’ve been building a couple of Firebase/GCP/React apps lately. Happy for you to screen me to see if I’m up to helping. My main interest is applying GNNs to ArchiMate models, which are graphs. I’ve built a few simple things with ML toolkits and binge-watched many ML videos, so I’d be assisting as a keen newb who can probably climb the learning curve.


Please do leave a comment on the github feature request so our SWEs know you exist and are willing to help out!

Hi Jason, I left a message.

Since then I’ve had some success implementing node prediction, edge prediction, and graph classification with TensorFlow.js, but I’m wondering if there’s a simple explanation of message passing that isn’t based on maths equations.


Say there are 3 nodes, A, B, and C, with edges A <-> B and A <-> C.

The first 3 slots are a one-hot encoding of which node it is (A, B, or C), and the last 2 slots are some features, say:
A = [0, 0, 1, 0, 1]
B = [0, 1, 0, 1, 1]
C = [1, 0, 0, 1, 1]

Then, after 1 round of message passing, what do A, B, and C look like?


A = average of A + B + C = [0.33, 0.33, 0.33, 0.67, 1]
B = average of B + A = [0, 0.5, 0.5, 0.5, 1]
C = average of C + A = [0.5, 0, 0.5, 0.5, 1]

And then round 2 would repeat the above.
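Written out as code, the averaging scheme above looks roughly like this. It is a minimal plain-JavaScript sketch (no TensorFlow.js needed); the function name `meanMessagePass` and the adjacency-list layout are illustrative choices, not anything from TensorFlow.js. Each node's new vector is the element-wise mean of its own vector and its neighbours' vectors, always computed from the previous round's values, and round 2 simply runs the same step again.

```javascript
// Adjacency list for the example graph: A <-> B, A <-> C
const neighbours = { A: ["B", "C"], B: ["A"], C: ["A"] };

// Initial vectors: 3 one-hot identity slots + 2 feature slots
const initial = {
  A: [0, 0, 1, 0, 1],
  B: [0, 1, 0, 1, 1],
  C: [1, 0, 0, 1, 1],
};

// One round: new vector = element-wise mean of the node itself and its neighbours.
// Every node reads the previous round's vectors, never partially updated ones.
function meanMessagePass(vectors, neighbours) {
  const next = {};
  for (const [node, nbrs] of Object.entries(neighbours)) {
    const group = [node, ...nbrs].map((n) => vectors[n]);
    next[node] = vectors[node].map(
      (_, i) => group.reduce((sum, v) => sum + v[i], 0) / group.length
    );
  }
  return next;
}

const round1 = meanMessagePass(initial, neighbours);
const round2 = meanMessagePass(round1, neighbours);

const fmt = (v) => v.map((x) => x.toFixed(2)).join(", ");
for (const node of ["A", "B", "C"]) {
  console.log(`${node}  round 1: [${fmt(round1[node])}]  round 2: [${fmt(round2[node])}]`);
}
```

Round 1 gives A ≈ [0.33, 0.33, 0.33, 0.67, 1], B = [0, 0.5, 0.5, 0.5, 1] and C = [0.5, 0, 0.5, 0.5, 1]. Round 2 gives A ≈ [0.28, 0.28, 0.44, 0.56, 1], B ≈ [0.17, 0.42, 0.42, 0.58, 1] and C ≈ [0.42, 0.17, 0.42, 0.58, 1]. After two rounds the one-hot identity slots have blurred together. Real GCN layers also multiply by learned weights after each round, which this sketch leaves out.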

I’ve actually tried something like this, but I’m not getting good results.

It’d be awesome if you or one of the team could tell me what I’m not understanding here about GCN embeddings, and even write out in a reply what the embeddings for A, B, and C might look like after each of the 2 message passing rounds.

It seems (?) like the above is wrong because the identity of the source node is diluted too much in each message passing round.

The benefit of the Google team or someone else doing this is that this information does not seem to exist anywhere on the Internet, and if it’s a blocker for me (as a developer type, not a maths type), then you can bet it’s a blocker for most other developer types trying to do GCNs.



I am unsure if anyone on the TensorFlow.js team has looked into GNNs yet. I shall ask, but as far as I know this is not something we have investigated. Alternatively, maybe someone else on the TF team who knows about GNNs can give a language-agnostic explanation of what you need to understand in order to implement it.

@Laurence_Moroney @pyu Who is the GNN expert on the TF team?

I need GCN help specifically. I’m ok with GNNs in that they’re pretty much the same as anything else.

I just don’t get how to apply the convolution / message passing.

i.e. after each round of message passing, is the tensor for a node:

  1. the average of itself and its neighbours, or
  2. the concatenation of itself and the average of its neighbours.

There is literally no simple explanation of this on the Internet anywhere.
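For what it's worth, both options exist in the literature. Option 1 (average the node together with its neighbours) is roughly what the Kipf & Welling GCN does, with a degree-normalised average followed by a learned weight matrix and a nonlinearity. Option 2 (keep the node's own vector and concatenate the aggregated neighbour vector) is the GraphSAGE-style update, which keeps the node's identity in its own slots instead of averaging it away. The plain-JavaScript sketch below shows only the aggregation step of each option on the 3-node example from earlier in the thread; the learned weights and nonlinearity a real layer would apply afterwards are omitted, and the function names are made up for illustration.

```javascript
const neighbours = { A: ["B", "C"], B: ["A"], C: ["A"] };
const h = {
  A: [0, 0, 1, 0, 1],
  B: [0, 1, 0, 1, 1],
  C: [1, 0, 0, 1, 1],
};

// Element-wise mean of a non-empty list of equal-length vectors
const mean = (vs) => vs[0].map((_, i) => vs.reduce((s, v) => s + v[i], 0) / vs.length);

// Option 1: mean over the node *and* its neighbours (self-loop included)
function meanWithSelf(h, neighbours) {
  const out = {};
  for (const [n, nbrs] of Object.entries(neighbours)) {
    out[n] = mean([h[n], ...nbrs.map((m) => h[m])]);
  }
  return out;
}

// Option 2: keep the node's own vector and append the mean of its neighbours.
// The result is twice as wide; a dense layer would normally map it back down.
function concatSelfAndNeighbourMean(h, neighbours) {
  const out = {};
  for (const [n, nbrs] of Object.entries(neighbours)) {
    out[n] = [...h[n], ...mean(nbrs.map((m) => h[m]))];
  }
  return out;
}

console.log("option 1, A:", meanWithSelf(h, neighbours).A.join(", "));
console.log("option 2, A:", concatSelfAndNeighbourMean(h, neighbours).A.join(", "));
```

With option 2, A's own identity slot is still exactly 1 after the step, whereas option 1 has already reduced it to 0.33.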

Hi @Greg_Matthews ,

This might not help, but I came across a post on Observable by Aman Tiwari about implementing a GCN with TF.js. It’s based on an older post by Thomas Kipf (featured in the Observable post) discussing GCNs, which you may have already seen.

I’m not sure if either of these helps your case, but I’m interested in learning how to perform basic graph transformations with TF.js, and your “success implementing node prediction, edge prediction, and graph classification” (mentioned earlier) sounds very interesting. Do you have a blog post on this somewhere? I have tried to translate the Observable post to CodePen here by rearranging the code, but I’ve come up against a Promise error that’s beyond my level, though I’m not even sure it is something I could use even if I managed to get it to work (like a dog chasing a car).

I’ve cobbled together a graph-editing UI on top of d3.js and am ultimately aiming for experimental interactivity (something like autocomplete, but more creative) that co-creates graphs as they are made.

Any help appreciated thanks.

Hi Samson,

I don’t have a blog post, but I’m happy to try to answer any questions. So far I’ve implemented basic GNN functionality, not GCN, so I will definitely study the post you linked.

Node prediction works by using xs (inputs) made of 2 vectors concatenated together, a source node vector and an edge vector, with the target node vector as the label. I also determine all the source + edge + target combinations that are NOT used and add them as training data that produce a “None” output (which is just another slot in the label vector). This is because softmax with categoricalCrossentropy always tries really hard to give a good answer, so you need to tell it when there is no good answer for the prediction you’re asking it to make.

Edge prediction uses a similar approach: the xs (input) vector is a concatenation of the source and target nodes, and the label vector is the edge.
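Here is one possible reading of the edge-prediction recipe as a plain-JavaScript sketch. The node and edge type names, the `triples` layout and `buildEdgeExamples` are all made up for illustration. Each observed (source, edge, target) triple becomes an example whose input is the one-hot source type concatenated with the one-hot target type, and whose label is the one-hot edge type; every ordered (source, target) pair that never occurs in the data becomes an extra example labelled with a trailing "None" slot. Node prediction would follow the same pattern, with source + edge as the input and the target as the label.

```javascript
const nodeTypes = ["BusinessActor", "BusinessRole", "ApplicationComponent"];
const edgeTypes = ["assignment", "serving"];

// Observed triples: [sourceType, edgeType, targetType] (illustrative data)
const triples = [
  ["BusinessActor", "assignment", "BusinessRole"],
  ["ApplicationComponent", "serving", "BusinessRole"],
];

// One-hot vector of length `size` with a 1 at `index`
const oneHot = (index, size) => Array.from({ length: size }, (_, i) => (i === index ? 1 : 0));

function buildEdgeExamples(nodeTypes, edgeTypes, triples) {
  const labelSize = edgeTypes.length + 1; // extra last slot = "None"
  const noneIndex = edgeTypes.length;
  const xs = [];
  const ys = [];
  const seenPairs = new Set();

  // Positive examples: input = source ++ target, label = edge type
  for (const [src, edge, tgt] of triples) {
    const s = nodeTypes.indexOf(src);
    const t = nodeTypes.indexOf(tgt);
    xs.push([...oneHot(s, nodeTypes.length), ...oneHot(t, nodeTypes.length)]);
    ys.push(oneHot(edgeTypes.indexOf(edge), labelSize));
    seenPairs.add(`${s},${t}`);
  }

  // Negative examples: every ordered (source, target) pair never seen -> "None"
  for (let s = 0; s < nodeTypes.length; s++) {
    for (let t = 0; t < nodeTypes.length; t++) {
      if (seenPairs.has(`${s},${t}`)) continue;
      xs.push([...oneHot(s, nodeTypes.length), ...oneHot(t, nodeTypes.length)]);
      ys.push(oneHot(noneIndex, labelSize));
    }
  }
  return { xs, ys };
}

const { xs, ys } = buildEdgeExamples(nodeTypes, edgeTypes, triples);
console.log(xs.length, ys.length); // 9 9
```

The nested arrays can then be turned into tensors with `tf.tensor2d(xs)` and `tf.tensor2d(ys)` for `model.fit`, with a softmax output layer and `categoricalCrossentropy` loss as described above.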

Happy to clarify anything. Let me know.


Hi @Greg_Matthews ,

Thank you, that sounds comprehensive, and there are a few points I need to unpack at my amateur level.

So far I was thinking of taking an input graph and using a depth-first search, say 4 nodes deep, to generate arrays of training data (xs/input per node). When you say all combinations NOT used, is that essentially subtracting the full training set from the set of EVERY permutation of unconnected nodes? Here I hit a block already.

But if there’s a chance you could post an MVP sketch on CodePen (or even JSFiddle, but maybe not Observable), GitHub, or something else more convenient, that would be incredible and easier for me to integrate. And though I know it would be a hassle, I’m sure a Medium post would get a lot of attention, but it’s understandable if you’d prefer not to.

Failing this I’ll try to come back with a codepen link interpreting your description if I can dev something not too embarrassing.

Not on TF, but at DeepMind: Aleksa Gordić appears to be one of the hottest names on GNNs.
Could be worth asking (but I don’t mean to interlope).

Hi Sam,

Re your “When you say…” question, yes, that’s what I’m doing. It’s very fast, though, even with 60+ node types and 10+ edge types.

I then add these to the actual triples found (source + edge + target), shuffle the result, and use a validation split of 20%.
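A minimal sketch of that shuffle-then-split step in plain JavaScript. One reason to shuffle before splitting: in TensorFlow.js, as in Keras, the `validationSplit` option of `model.fit` takes the *last* fraction of the samples before any shuffling, so if all the "None" examples were appended at the end they would land only in the validation set. The helper names, the seeded PRNG (mulberry32) and the toy data are illustrative assumptions.

```javascript
// Small seeded PRNG (mulberry32) so the shuffle is reproducible; returns values in [0, 1)
function mulberry32(a) {
  return function () {
    let t = (a += 0x6d2b79f5);
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

// Fisher-Yates shuffle that applies the same permutation to xs and ys
function shuffleTogether(xs, ys, rand) {
  const x = xs.slice();
  const y = ys.slice();
  for (let i = x.length - 1; i > 0; i--) {
    const j = Math.floor(rand() * (i + 1));
    [x[i], x[j]] = [x[j], x[i]];
    [y[i], y[j]] = [y[j], y[i]];
  }
  return { x, y };
}

// Hold out the last `fraction` of the (already shuffled) examples for validation
function trainValidationSplit(x, y, fraction = 0.2) {
  const cut = Math.floor(x.length * (1 - fraction));
  return {
    train: { x: x.slice(0, cut), y: y.slice(0, cut) },
    validation: { x: x.slice(cut), y: y.slice(cut) },
  };
}

// Toy data: each label equals its input, so pairing can be checked after shuffling
const xs = Array.from({ length: 10 }, (_, i) => [i]);
const ys = Array.from({ length: 10 }, (_, i) => [i]);
const { x, y } = shuffleTogether(xs, ys, mulberry32(42));
const { train, validation } = trainValidationSplit(x, y, 0.2);
console.log(train.x.length, validation.x.length); // 8 2
```

Tensors built from `train` and `validation` can then be passed to `model.fit` via its `validationData` option instead of `validationSplit`.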

I also found that autoencoding the xs seems to give a better training result, i.e. crunching down the quite sparse xs vector, which is 120+ or 70+ wide depending on whether I’m doing edge prediction or node prediction respectively, to a vector that’s maybe 10 wide. To do the autoencoding, you just use the xs as both the xs and the labels, and make sure the middle layer is thinner than the input and output layers. See the Observable notebook “Tensorflow.js Dimensionality Reduction (Autoencoder)” by j.carson.
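To illustrate the idea without any library, here is a deliberately tiny *linear* autoencoder in plain JavaScript, trained with per-example gradient descent. It is a swapped-in stand-in for the TensorFlow.js dense-layer version in the notebook mentioned above, not a copy of it: the inputs double as the targets, and the code layer (`k` wide) is narrower than the input. The data, sizes, learning rate and epoch count are arbitrary assumptions. In TensorFlow.js you would instead stack `tf.layers.dense` layers with a narrow middle layer and call `model.fit(xs, xs, ...)`.

```javascript
// Tiny deterministic PRNG (LCG) so the example is reproducible; returns values in [-0.5, 0.5)
function makeRand(seed) {
  let s = seed >>> 0;
  return () => {
    s = (Math.imul(s, 1664525) + 1013904223) >>> 0;
    return s / 4294967296 - 0.5;
  };
}

// Linear autoencoder: code = W1 x (k wide), reconstruction = W2 code (d wide).
// The training target is the input itself.
function trainLinearAutoencoder(data, k, { epochs = 500, lr = 0.05, seed = 1 } = {}) {
  const d = data[0].length;
  const rand = makeRand(seed);
  const W1 = Array.from({ length: k }, () => Array.from({ length: d }, () => rand())); // k x d
  const W2 = Array.from({ length: d }, () => Array.from({ length: k }, () => rand())); // d x k
  const matVec = (M, v) => M.map((row) => row.reduce((s, w, j) => s + w * v[j], 0));
  // Sum of squared reconstruction errors, averaged over the examples
  const loss = () =>
    data.reduce((total, x) => {
      const r = matVec(W2, matVec(W1, x));
      return total + r.reduce((s, ri, i) => s + (ri - x[i]) ** 2, 0);
    }, 0) / data.length;

  const initialLoss = loss();
  for (let epoch = 0; epoch < epochs; epoch++) {
    for (const x of data) {
      const h = matVec(W1, x); // encode
      const r = matVec(W2, h); // decode
      const err = r.map((ri, i) => ri - x[i]); // target is the input itself
      // Gradient w.r.t. the code h, computed before W2 is updated
      const gradH = h.map((_, a) => err.reduce((s, e, i) => s + W2[i][a] * e, 0));
      for (let i = 0; i < d; i++) for (let a = 0; a < k; a++) W2[i][a] -= lr * err[i] * h[a];
      for (let a = 0; a < k; a++) for (let j = 0; j < d; j++) W1[a][j] -= lr * gradH[a] * x[j];
    }
  }
  return { encode: (x) => matVec(W1, x), initialLoss, finalLoss: loss() };
}

// Sparse 6-wide vectors squeezed down to a 2-wide code
const data = [
  [1, 0, 0, 1, 0, 0],
  [0, 1, 0, 0, 1, 0],
  [0, 0, 1, 0, 0, 1],
  [1, 0, 0, 0, 1, 0],
];
const ae = trainLinearAutoencoder(data, 2);
console.log("loss:", ae.initialLoss.toFixed(3), "->", ae.finalLoss.toFixed(3));
console.log("code for first example:", ae.encode(data[0]));
```

After training, `encode` plays the role of the thin middle layer: its output is what you would feed to the prediction model instead of the sparse input.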

I’ve found Zak Jost on the Graph Machine Learning Discord channel (he has a YouTube channel as well) to be super helpful. Zak pointed out that what I was doing, while trying to get GCN working so I can make context-sensitive edge/node predictions, was “Simple GCN”.

Simple GCN is when you just do 2 rounds of message passing (MP) and then pass the resulting vectors through a regular neural net.
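If “Simple GCN” here means something like SGC (Simplifying Graph Convolutional Networks, Wu et al. 2019), the message-passing rounds have no learned weights, so they can be precomputed once as S^K X, where S = D^-1/2 (A + I) D^-1/2 is the adjacency matrix with self-loops, symmetrically normalised by node degree, and X has one feature row per node. Below is a plain-JavaScript sketch with K = 2 on the thread's 3-node example. The symmetric normalisation is the standard GCN choice and may differ from exactly what Zak described; the rows of the result would then be fed to an ordinary dense network.

```javascript
// Adjacency matrix for A <-> B, A <-> C (node order: A, B, C)
const adj = [
  [0, 1, 1],
  [1, 0, 0],
  [1, 0, 0],
];
// Feature matrix X: one row per node
const X = [
  [0, 0, 1, 0, 1],
  [0, 1, 0, 1, 1],
  [1, 0, 0, 1, 1],
];

// Plain matrix product P (n x m) times Q (m x p)
const matMul = (P, Q) =>
  P.map((row) => Q[0].map((_, j) => row.reduce((s, p, k) => s + p * Q[k][j], 0)));

// S = D^-1/2 (A + I) D^-1/2, where D is the degree matrix of A + I
function normalizedAdjacency(adj) {
  const aHat = adj.map((row, i) => row.map((v, j) => v + (i === j ? 1 : 0)));
  const deg = aHat.map((row) => row.reduce((s, v) => s + v, 0));
  return aHat.map((row, i) => row.map((v, j) => v / Math.sqrt(deg[i] * deg[j])));
}

// K rounds of propagation with no learned weights: returns S^K X
function sgcFeatures(adj, X, K = 2) {
  const S = normalizedAdjacency(adj);
  let H = X;
  for (let k = 0; k < K; k++) H = matMul(S, H);
  return H;
}

const features = sgcFeatures(adj, X, 2);
features.forEach((row, i) => console.log("ABC"[i], row.map((v) => v.toFixed(3)).join(", ")));
```

Because nothing in the propagation is trainable, it runs once as preprocessing, and only the dense network afterwards is trained.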

“Proper GCN” is when each layer performs the message passing itself, which seems to require custom TensorFlow.js layers. That’s a bit beyond me right now, but it’s my next step to try.
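For comparison, here is what one “proper” GCN layer computes, sketched in plain JavaScript (forward pass only). It is the Kipf & Welling layer written in message-passing form: each node sums its own and its neighbours' vectors, each multiplied by a weight matrix W and scaled by 1 / sqrt(deg(u) * deg(v)) with self-loops counted in the degree, then applies ReLU. That is equivalent to ReLU(S H W) with the normalised adjacency S. The weights here are hand-picked purely for illustration, and `gcnLayer` is a made-up name; a custom TensorFlow.js layer would hold W as a trainable weight and let autodiff learn it.

```javascript
// Edge list for the example graph (undirected): A-B, A-C
const nodes = ["A", "B", "C"];
const edges = [["A", "B"], ["A", "C"]];
const H = {
  A: [0, 0, 1, 0, 1],
  B: [0, 1, 0, 1, 1],
  C: [1, 0, 0, 1, 1],
};

// One GCN layer in message-passing form:
//   h'_v = ReLU( sum over u in N(v) plus v itself of (h_u W) / sqrt(deg(u) * deg(v)) )
// deg counts the self-loop. W is inDim x outDim.
function gcnLayer(nodes, edges, H, W) {
  const nbrs = Object.fromEntries(nodes.map((n) => [n, [n]])); // self-loop first
  for (const [a, b] of edges) {
    nbrs[a].push(b);
    nbrs[b].push(a);
  }
  const deg = Object.fromEntries(nodes.map((n) => [n, nbrs[n].length]));
  const outDim = W[0].length;
  const out = {};
  for (const v of nodes) {
    const acc = new Array(outDim).fill(0);
    for (const u of nbrs[v]) {
      const norm = 1 / Math.sqrt(deg[u] * deg[v]);
      // Message from u: h_u W, scaled by the normalisation constant
      for (let j = 0; j < outDim; j++) {
        for (let i = 0; i < H[u].length; i++) acc[j] += norm * H[u][i] * W[i][j];
      }
    }
    out[v] = acc.map((x) => Math.max(0, x)); // ReLU
  }
  return out;
}

// Hand-picked 5 x 2 weight matrix purely for illustration; in practice it is learned.
const W = [
  [1, 0],
  [0, 1],
  [1, -1],
  [0, 1],
  [0.5, 0.5],
];
const out = gcnLayer(nodes, edges, H, W);
for (const n of nodes) console.log(n, out[n].map((x) => x.toFixed(3)).join(", "));
```

Stacking two such layers (each with its own W) gives the usual 2-layer GCN; the learned W is what distinguishes it from the weight-free propagation of “Simple GCN”.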

I think you’re generally ahead of me on “Proper GCN” with following the matrix algorithms, etc.

Hi @Greg_Matthews ,

I’ll follow the leads, Zak’s work looks quite comprehensive and I see he’s about to launch two GNN courses, so I’ll take a look at this and Carson’s work whilst upping my TF and matrix coding ability. Regarding Proper GCN I’m really not even caught up with anyone let alone ahead—I still have a lot of basic ground to cover.
Many thanks!

If you find any tensorflow.js code that shows how to build a custom GCN layer that does message passing let me know.

I’ve been messing around with integrating GPT-3 into my app as well, which is hilarious when you feed it a bunch of data and ask it questions.

Sounds interesting, and absolutely—will keep an eye out specifically!
If I figure out an MVP TF.js GNN prediction model I might even have a go at writing a first Medium post. Until then good luck and keep experimenting!

Seems there is a lot of interest around GNNs. I believe @rishit_dagli may also be interested in this. It may be wise for you folk to get together, compare your various skills, and see how this could come to life. If there are enough people willing to implement this, it could at some point turn into a TensorFlow.js SIG research area led by the folk here.

@pyu What do you think?


Just to clarify, do we need GNN or GCN?

I’ve implemented what Zak Jost calls “Simple GCN” where I do 2 rounds of message passing and then pretty much implement a normal NN.

I think it’d be good if we can clarify the extent of any GNN/GCN work. For me, it seems the best implementation is where you have a graph of neural nets that implement message passing.

This aligns, I think, with the Stanford/Jure Leskovec lectures on how to do a GCN, i.e. “proper GCN”.

Right now, with TensorFlow.js this seems to require building a custom layer, and then assembling these in a graph structure, rather than the usual tf.sequential call.

It seems like all the building blocks might be there, but to make a GCN implementation in tensorflow.js that’s easy to use, the pieces would seem to be:

  1. A graph structure provided by TensorFlow.js that you populate with your embeddings. A standardised graph structure seems useful because it would let TensorFlow.js simplify the rest of the GCN process.
  2. Ability to then do tf.graph(myGraphStructure) which would be analogous to tf.sequential()
  3. Ability to plug in your particular message passing function, so you could override how neighbor embeddings are combined into the source embedding.
  4. Ability to save/load, as per the current functionality.
  5. Perhaps support for random walks where the graph won’t fit into memory, although the random walk could perhaps happen in step 1, where you only grab a subset of the graph.

This might provide an initial set of requirements for others to correct/improve.
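To make item 3 a bit more concrete, here is a rough plain-JavaScript sketch of a pluggable message-passing step: the caller supplies `message`, `aggregate` and `update` functions while the graph handling stays fixed. All the names (`messagePassing`, the `graph` layout, the plug-in functions) are hypothetical and not part of TensorFlow.js; a real version would operate on tensors and trainable weights.

```javascript
// A minimal graph structure: node feature vectors plus an undirected edge list
const graph = {
  features: { A: [1, 0], B: [0, 1], C: [1, 1] },
  edges: [["A", "B"], ["A", "C"]],
};

// Generic message-passing round with user-supplied functions:
//   message(hSource, hTarget) -> vector sent along an edge
//   aggregate(messages, dim)  -> single vector (must handle an empty list)
//   update(hOld, aggregated)  -> new node vector
function messagePassing(graph, { message, aggregate, update }) {
  const inbox = Object.fromEntries(Object.keys(graph.features).map((n) => [n, []]));
  for (const [a, b] of graph.edges) {
    inbox[b].push(message(graph.features[a], graph.features[b]));
    inbox[a].push(message(graph.features[b], graph.features[a]));
  }
  const features = {};
  for (const [n, h] of Object.entries(graph.features)) {
    features[n] = update(h, aggregate(inbox[n], h.length));
  }
  return { ...graph, features }; // same edges, new node vectors
}

// Example plug-ins: identity messages, mean aggregation, GraphSAGE-style concat update
const identityMessage = (hSrc) => hSrc;
const meanAggregate = (msgs, dim) =>
  msgs.length === 0
    ? new Array(dim).fill(0)
    : msgs[0].map((_, i) => msgs.reduce((s, m) => s + m[i], 0) / msgs.length);
const concatUpdate = (hOld, agg) => [...hOld, ...agg];

const next = messagePassing(graph, {
  message: identityMessage,
  aggregate: meanAggregate,
  update: concatUpdate,
});
console.log(next.features.A.join(", ")); // 1, 0, 0.5, 1
```

Swapping `meanAggregate` for a sum or max, or `concatUpdate` for a plain average, changes the flavour of GNN without touching the graph-handling code, which is the kind of override item 3 describes.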


Do let us know how you get on there Greg. Interested to see what you end up creating here.