You could just learn to estimate the right number of recursions, while also passing 'backtracking'/'state' information down to the next nested level. Kind of like how state space models encode extractable information via a basis function representation, you could encode extractable recursion state into the embedding. See also transformers that can learn to recognize balanced parentheses nested up to n deep (bounded-depth Dyck languages).
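
To make the "recursion state in a fixed-size vector" idea concrete, here is a rough hand-written sketch (not a learned model, and not taken from any particular paper). It assumes a depth budget `max_depth` and keeps the whole stack of open brackets in a fixed-length list, which is roughly the kind of information an embedding would have to carry for a bounded-depth Dyck recognizer. The names `encode_prefix` and `is_balanced` are made up for illustration.

```python
from typing import List, Optional

# Hypothetical bracket alphabet: three bracket types.
PAIRS = {"(": ")", "[": "]", "{": "}"}
OPENERS = list(PAIRS)  # position in this list identifies the bracket type
CLOSER_TO_TYPE = {c: i for i, c in enumerate(PAIRS.values())}


def encode_prefix(s: str, max_depth: int) -> Optional[List[int]]:
    """Return a fixed-length state vector for the prefix s, or None if rejected.

    Slot d holds 0 if nesting level d is unused, or (type index + 1) of the
    bracket currently open at that level. The vector never grows, so the
    "recursion state" fits in a fixed-size representation.
    """
    state = [0] * max_depth
    depth = 0
    for ch in s:
        if ch in PAIRS:
            if depth == max_depth:
                return None  # would exceed the depth budget
            state[depth] = OPENERS.index(ch) + 1
            depth += 1
        elif ch in CLOSER_TO_TYPE:
            # A closer must match the most recently opened bracket.
            if depth == 0 or state[depth - 1] != CLOSER_TO_TYPE[ch] + 1:
                return None
            depth -= 1
            state[depth] = 0  # "backtrack": clear the level we just left
        else:
            return None  # character outside the alphabet
    return state


def is_balanced(s: str, max_depth: int) -> bool:
    """Accept iff s is well nested and never nests deeper than max_depth."""
    state = encode_prefix(s, max_depth)
    return state is not None and not any(state)
```

For example, `encode_prefix("([", 3)` gives `[1, 2, 0]`: two levels in use, one spare. A learned version would have to recover something equivalent from the embedding instead of having it written out by hand.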