This sounds super interesting, but as someone who knows little about ML or math in general, could you give an ELI5?

I have a bunch of points. I want to fit a curve to them. I could write a function that takes a bunch of parameters as floats that specify the curve, and an x coordinate as a float, and have it output the most likely y value as a float.
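As a rough sketch of that function, assuming the curve is a quadratic (the name `curve` and the three-parameter shape are just illustrative choices, not anything specific):

```python
def curve(params, x):
    """Return the predicted y for a given x.

    Assumption for illustration: the curve is a quadratic a*x^2 + b*x + c,
    so params is a tuple of three floats (a, b, c).
    """
    a, b, c = params
    return a * x * x + b * x + c


# Example: the parabola y = x^2 + 2x + 3 evaluated at x = 2.
y = curve((1.0, 2.0, 3.0), 2.0)
```

Note that the body only uses + and *, which is what matters for the next step.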

If I have a library, though, that lets me add and multiply not just floats but entire computation subgraphs with the exact same + and * operators, I can have the library run that function in reverse automatically (working backward from the output to find how much each parameter contributed to it), and say: “optimize the parameters to minimize the difference between the curve and the data points.”
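Here is a toy version of that idea in plain Python, to show there is no magic. It is a minimal sketch, not how any real library is implemented: the `Value` class, the straight-line `curve`, the four data points, and the step size 0.01 are all made up for illustration.

```python
class Value:
    """A float that remembers how it was computed (a node in a computation graph)."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0                  # d(final output) / d(this value)
        self._parents = parents          # the Values this one was computed from
        self._local_grads = local_grads  # d(this value) / d(each parent)

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    __radd__ = __add__
    __rmul__ = __mul__

    def backward(self):
        """Walk the graph from the output back to the inputs (chain rule)."""
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            for parent, local in zip(node._parents, node._local_grads):
                parent.grad += local * node.grad


def curve(params, x):
    # Same shape of function as before, here a straight line: y = w*x + b.
    w, b = params
    return w * x + b


# Hypothetical data lying exactly on y = 2x + 1.
points = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0), (3.0, 7.0)]
params = [Value(0.0), Value(0.0)]

for step in range(2000):
    # Forward pass: sum of squared differences between curve and data.
    loss = Value(0.0)
    for x, y in points:
        err = curve(params, x) + (-y)
        loss = loss + err * err
    # Backward pass: fill in p.grad for each parameter.
    for p in params:
        p.grad = 0.0
    loss.backward()
    # Nudge each parameter a little in the direction that lowers the loss.
    for p in params:
        p.data = p.data + (-0.01) * p.grad

w, b = params[0].data, params[1].data
```

After the loop, `w` and `b` should end up very close to 2 and 1. The point is that `curve` was written as ordinary arithmetic; the operator overloading recorded the graph, and `backward()` did the reversing.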

LLMs and other ML systems, to paint with a very broad stroke, solve that same problem with billions of parameters, which means searching a space with billions of dimensions. Developing intuition for that many dimensions is hard! But the code is simple, because once you’ve done the math for the forward pass, you can go straight from chalkboard to Python code, and the libraries largely handle the backward pass and build a GPU-accelerated training process for you automatically!