From their code:
import torch

A = torch.randn(2048, 2048, device='cuda', dtype=torch.bfloat16)
B = torch.randn(2048, 2048, device='cuda', dtype=torch.bfloat16)
ref = torch.mm(A, B)
for _ in range(1000):
    assert (torch.mm(A, B) - ref).abs().max().item() == 0
I’m sort of surprised that Torch doesn’t have some kind of lazy evaluation thing to avoid computing anything here. I thought that was one of the nice things about all these fancy frameworks (if I wanted the computer to actually do silly things when I asked it to, I would use BLAS directly, right?).
Maybe I'm missing something, but in this case wouldn't being lazy be pure overhead? I don't see anything that can be lazy here. The reference is computed once, nanoseconds before it's needed, and the test cases are computed at the time of comparison, then tossed away.
What would one hope to achieve by making this case lazy? If you wanted these to run in parallel on a multi-GPU system, you would use the appropriate parallel interface.
I mean, if you wait long enough, it is asking for the evaluation of something that can be identified as definitionally zero.

I don't understand. Since it's not using the parallel interface, only one operation can happen at a time. This would literally be sequential execution with extra overhead in this case. Again, what would one hope to achieve by doing things lazily, since the lazy operations would immediately be followed by their evaluation?
The parallel interface, which is async, is probably what you're looking for.
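Worth noting: PyTorch's CUDA ops are already asynchronous with respect to the host. A call like torch.mm returns as soon as the kernel is enqueued on the current stream, and the host only blocks when it needs a concrete value (e.g. .item()) or synchronizes explicitly. A minimal illustration (requires a CUDA device):

import torch

A = torch.randn(2048, 2048, device='cuda')
B = torch.randn(2048, 2048, device='cuda')

C = torch.mm(A, B)        # returns immediately; the kernel is only enqueued
torch.cuda.synchronize()  # block the host until the GPU actually finishes
print(C[0, 0].item())     # .item() would also force a synchronization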
Let's look at the subtraction in this case.
If evaluation is lazy, then the subtraction operator gets fed two unevaluated matrix multiplies.
If it's a dumb subtraction operator, this gives us no benefit. Eventually it evaluates both and then subtracts. And it has some extra overhead like you said.
But if it's a smart subtraction operator, it can recognize that both operands are the same expression, and then it can return all zeros without evaluating anything.
And even better than just skipping the matrix math, "all 0s" can be a stub object that takes O(1) time to set up. And then .abs().max() will be instant too.
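Here's a toy sketch of that idea. The class names LazyMatMul and ZeroStub are made up for illustration; this is not how PyTorch actually behaves, just what a "smart" lazy subtraction could look like:

import torch

class ZeroStub:
    """O(1) placeholder for an all-zeros result; allocates no storage."""
    def __init__(self, shape):
        self.shape = shape

    def abs(self):
        return self   # |0| is still 0

    def max(self):
        return 0.0    # no reduction kernel needed

class LazyMatMul:
    """An unevaluated torch.mm(a, b)."""
    def __init__(self, a, b):
        self.a, self.b = a, b

    def __sub__(self, other):
        # "Smart" subtraction: if both sides are the same mm expression,
        # the difference is definitionally zero -- skip the matmuls.
        if isinstance(other, LazyMatMul) and self.a is other.a and self.b is other.b:
            return ZeroStub((self.a.shape[0], self.b.shape[1]))
        return self.evaluate() - other.evaluate()

    def evaluate(self):
        return torch.mm(self.a, self.b)

A = torch.randn(64, 64)
B = torch.randn(64, 64)
diff = LazyMatMul(A, B) - LazyMatMul(A, B)
print(diff.abs().max())  # 0.0, without ever running the matmuls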
I see now, thank you. I was stuck on the "lazy evaluation" part, rather than the optimization part they were actually suggesting.
The Python commands are encountered sequentially. One could imagine a library where the Python commands build up the computation graph under the hood. Then the library would be able to take advantage of situations like this one (or, more practically, reorder multiplications and/or avoid unnecessary temporaries).
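This is roughly what torch.compile does: it traces the Python code into a graph, which the backend is then free to rewrite, e.g. via common subexpression elimination. Whether any given backend actually folds this particular case down to a constant is backend-dependent; a sketch:

import torch

@torch.compile
def diff(A, B):
    # Both mm calls are identical subexpressions in the traced graph,
    # so the compiler is free to deduplicate them (CSE) and, in
    # principle, fold the whole subtraction to zero.
    return (torch.mm(A, B) - torch.mm(A, B)).abs().max()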