New, clean implementation of Faster R-CNN in both TensorFlow 2/Keras and PyTorch

Hi everyone,

I recently put the finishing touches on my Faster R-CNN self-learning exercise. My goal was to replicate the model from scratch using only the paper. That was a bit ambitious, and I eventually had to relent and peek at some existing implementations to understand a few things the paper is unclear on. The repo is here: trzy/FasterRCNN on GitHub (clean and readable implementations of Faster R-CNN in PyTorch and TensorFlow 2 with Keras).

I wrote both a PyTorch and a TensorFlow implementation. I’d like to think they are pretty clean, readable, and easy to use. I also documented some of my struggles and takeaways in the README.

One thing that continues to bother me is that the TensorFlow version needs an extra tf.stop_gradient() around a tf.less comparison in the regression loss functions, even though the loss function itself is differentiable. The PyTorch version doesn’t need this. I stumbled upon the fix by accident, so I might make a post about it on one of the other sub-forums. Without the explicit stop_gradient, the model still learns but achieves significantly lower precision. I’d love to hear how others would approach debugging an issue like this.
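For anyone who hasn’t seen the pattern, here is a minimal sketch of a smooth L1 regression loss with the tf.less mask wrapped in tf.stop_gradient(). It is not the repo’s code: the function name, the sigma parameterization (common in Faster R-CNN implementations; sigma = 1 gives the paper’s form), and the summed reduction are all assumptions.

```python
import tensorflow as tf

def smooth_l1_loss(y_true, y_pred, sigma=1.0):
    """Smooth L1 loss summed over all elements (hypothetical sketch).

    Quadratic where |x| < 1/sigma^2, linear elsewhere, with x = y_true - y_pred.
    """
    sigma_sq = sigma * sigma
    x = y_true - y_pred
    x_abs = tf.abs(x)
    # 1.0 where the quadratic branch applies, 0.0 otherwise. The mask is wrapped
    # in tf.stop_gradient so backprop treats it as a constant.
    is_small = tf.stop_gradient(tf.cast(tf.less(x_abs, 1.0 / sigma_sq), tf.float32))
    quadratic = 0.5 * sigma_sq * x * x
    linear = x_abs - 0.5 / sigma_sq
    loss = is_small * quadratic + (1.0 - is_small) * linear
    return tf.reduce_sum(loss)

# Toy check: one element in the quadratic region, one in the linear region.
y_true = tf.constant([0.0, 0.0])
y_pred = tf.Variable([0.5, 2.0])
with tf.GradientTape() as tape:
    loss = smooth_l1_loss(y_true, y_pred)
grad = tape.gradient(loss, y_pred)
print(float(loss))   # 1.625  (0.5 * 0.5^2 + (2.0 - 0.5))
print(grad.numpy())  # [0.5 1. ]
```

The mask only selects which branch each element uses; the gradient flows through the quadratic and linear terms themselves.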




@Bart Great work; you clearly put a lot of effort into this. Starred. A small request: your README contains a lot of development details that could be moved into a separate .md file. That way, the front page can give more top-level highlights, for example how to reproduce the results, how to fine-tune, or how to train from scratch on custom data.

Thanks for the suggestions! I can certainly split the Development Learnings into a separate document. Re: fine-tuning and training from scratch on custom data, I suppose new data would have to be provided in the same format as the VOC dataset (which should be fairly simple to do). Do you think it would be worthwhile to elaborate on this point in the README or in a separate document? I was thinking about making an annotation program for generating custom data (probably just an HTML5/JS thing) if the need arises.
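For anyone going the custom-data route, here is a minimal sketch of generating a PASCAL VOC-style annotation file with Python’s standard library. It assumes the repo’s loader reads the standard VOC fields (filename, size, object/name, difficult, bndbox); the helper name make_voc_annotation and the box tuple format are hypothetical, so check them against the actual dataset loader.

```python
import xml.etree.ElementTree as ET

def make_voc_annotation(filename, width, height, objects, depth=3):
    """Build a VOC-style annotation XML string (hypothetical helper).

    objects: list of (class_name, xmin, ymin, xmax, ymax) tuples with
    integer pixel coordinates.
    """
    root = ET.Element("annotation")
    ET.SubElement(root, "filename").text = filename

    # Image dimensions.
    size = ET.SubElement(root, "size")
    for tag, value in (("width", width), ("height", height), ("depth", depth)):
        ET.SubElement(size, tag).text = str(value)

    # One <object> element per bounding box.
    for name, xmin, ymin, xmax, ymax in objects:
        obj = ET.SubElement(root, "object")
        ET.SubElement(obj, "name").text = name
        ET.SubElement(obj, "difficult").text = "0"
        bndbox = ET.SubElement(obj, "bndbox")
        for tag, value in (("xmin", xmin), ("ymin", ymin), ("xmax", xmax), ("ymax", ymax)):
            ET.SubElement(bndbox, tag).text = str(value)

    return ET.tostring(root, encoding="unicode")

xml_text = make_voc_annotation("000001.jpg", 500, 375, [("dog", 48, 240, 195, 371)])
print(xml_text)
```

In the standard VOC layout, each string would be saved as Annotations/<image_id>.xml alongside JPEGImages/<image_id>.jpg, with the image IDs listed in ImageSets/Main/trainval.txt.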

I think that would be nice, and new practitioners will surely find it useful. :-)

I’ve tried your implementation and have a question about adding new backbones. The VGG16 version works nicely and can easily be extended to VGG19 too, but adding models like ResNet50 or MobileNet seems harder because of their skip connections. Is it possible to support such networks, or does the implementation only support straightforward sequential models?
Thanks in advance

I’ve never tried implementing this in TF2, but I figure it must be possible. TF2 certainly supports skip connections (the Keras functional API handles them). I think the backbone is relatively self-contained, although you may need to add or modify the code that loads the initial parameters. Also, I believe TF2 has built-in implementations of these models (in tf.keras.applications) that you could leverage without re-defining the whole model layer by layer yourself. You would just write a new backbone and use it instead of my VGG16 one.
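As a rough illustration of that suggestion, here is a sketch of a ResNet50 feature extractor built from tf.keras.applications. The helper name, the choice of the conv4_block6_out layer (output stride 16, like the VGG16 conv5_3 features Faster R-CNN typically uses), and the weights=None default (to avoid a download) are assumptions, not the repo’s backbone interface.

```python
import tensorflow as tf

def build_resnet50_feature_extractor(weights=None):
    """Map images to a stride-16 ResNet50 feature map (hypothetical helper).

    Pass weights="imagenet" to load pretrained parameters (downloads a file).
    """
    # include_top=False drops the classifier; None spatial dims allow any image size.
    base = tf.keras.applications.ResNet50(
        include_top=False, weights=weights, input_shape=(None, None, 3)
    )
    # Output of the last block in stage 4: stride 16, 1024 channels.
    # The residual (skip) connections live inside the Keras graph already.
    features = base.get_layer("conv4_block6_out").output
    return tf.keras.Model(inputs=base.input, outputs=features)

extractor = build_resnet50_feature_extractor()
images = tf.zeros((1, 224, 224, 3))
feature_map = extractor(images)
print(feature_map.shape)  # (1, 14, 14, 1024)
```

Real inputs would go through tf.keras.applications.resnet50.preprocess_input first, and the RPN and detector heads would need to accept 1024 feature channels instead of VGG16’s 512.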

Unfortunately I’m a bit too busy right now to do it, but if you want to submit a PR, I can take a look. I may also tackle this sometime in the next few months, since I’m curious about extending the model to use more modern backbones.