When you split a network across multiple GPUs, you have to wait for the slowest node and the longest-latency link at every synchronization barrier. As soon as you can remove most of these barriers, compute over networks without latency guarantees becomes more practical, as does non-homogeneous compute (i.e. mixing different GPU models).
It either needs some significant hyperparameter tuning (besides tweaking alpha, which doesn't seem to do much for me), or some fancier initialization (I tried both the PyTorch default and orthogonal init, no difference), or maybe my scalar optimizer doesn't work on it (I have a custom optimizer for scalars which speeds up convergence vs. Adam, but for DyT layers it seems to be just as good as Adam, no better), or maybe it only catches up after billions of tokens (which I don't have the budget to test). One thing that did seem to help a bit is this initialization:
    # RMSNorm of the input (without a learned scale), computed in float32
    y = x.to(torch.float32)
    y = y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + 1e-6)
    # what DyT produces before its learned per-channel weight
    z = torch.tanh(self.alpha * x)
    # per-channel scale that maps tanh(alpha * x) onto the RMSNorm output
    scale = (y / (z + 1e-6)).mean(dim=-2).flatten()
    self.weight.detach().copy_(scale)
This basically tries to initialize the weights so that the output of DyT is close to what RMSNorm would have produced, and it seems to help.

For the dataset I just use FineWeb-Edu.
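For context, here is a minimal sketch of how that initialization could sit inside a DyT-style layer. The module structure, the alpha_init value, and the init_from_rmsnorm helper are my own illustration (assuming activations of shape (tokens, dim)), not the exact code from the experiment:

    import torch
    import torch.nn as nn

    class DyT(nn.Module):
        """Dynamic Tanh: weight * tanh(alpha * x) + bias, no normalization."""
        def __init__(self, dim, alpha_init=0.5):
            super().__init__()
            self.alpha = nn.Parameter(torch.full((1,), alpha_init))
            self.weight = nn.Parameter(torch.ones(dim))
            self.bias = nn.Parameter(torch.zeros(dim))

        def forward(self, x):
            return self.weight * torch.tanh(self.alpha * x) + self.bias

        @torch.no_grad()
        def init_from_rmsnorm(self, x):
            # pick the per-channel weight so that, on a sample batch x of shape
            # (tokens, dim), the DyT output roughly matches the RMSNorm output
            y = x.to(torch.float32)
            y = y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + 1e-6)
            z = torch.tanh(self.alpha * x)
            self.weight.copy_((y / (z + 1e-6)).mean(dim=-2).flatten())

The intended usage here would be to call init_from_rmsnorm once on a representative batch of activations before training starts.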
I recommend e.g. the original ResNet paper and its follow-up from Kaiming He et al.
For a modern take on RNNs, read https://arxiv.org/abs/2303.06349 by DeepMind.
The essential point there is that the largest eigenvalue (the spectral radius) of the recurrent weight matrix needs to be around 1, so that repeated application of the linear transformation neither grows nor shrinks the activations.
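As a toy illustration of that point (my own example, not from the paper): repeated application of a linear map scales the state by roughly the spectral radius per step, so anything much above or below 1 explodes or vanishes.

    import torch

    torch.manual_seed(0)
    dim, steps = 64, 200
    A = torch.randn(dim, dim) / dim ** 0.5          # random linear map
    radius = torch.linalg.eigvals(A).abs().max()    # its spectral radius

    for rho in (0.9, 1.0, 1.1):
        W = A * (rho / radius)                      # rescale to spectral radius rho
        h = torch.randn(dim)
        for _ in range(steps):
            h = W @ h                               # repeated linear application
        print(f"spectral radius {rho}: ||h|| after {steps} steps = {h.norm().item():.3e}")
    # ~0.9 decays toward zero, ~1.1 blows up, ~1.0 stays on a stable scale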
Batch norm and the others are important for faster convergence because they force the model to focus on creating second- and higher-order nonlinearities: a simple shift in the mean or std of the output is normalized away, so the gradient cannot point in a direction that would only change those properties of the output distribution.
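A quick sanity check of that (my own toy example): in training mode, BatchNorm's output is unchanged if you shift and rescale its input, so a weight update that only moved the pre-norm mean/std would change nothing downstream.

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    bn = nn.BatchNorm1d(16)            # training mode: normalizes with batch statistics
    x = torch.randn(32, 16)
    shifted = 3.0 * x + 5.0            # only changes the mean and std of each feature

    out_plain = bn(x)
    out_shifted = bn(shifted)
    # identical up to BatchNorm's eps: the mean/std change is normalized away
    print(torch.allclose(out_plain, out_shifted, atol=1e-4))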
Surely you would want to compare against the output of the LayerNorm without its weight and bias to get an impression of their similarity. I guess it doesn't matter if the final result works, but I feel like looking at the bit they are actually changing in isolation might give better insight into what is happening.
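Something like the following would do that isolated comparison (my own sketch, assuming Gaussian stand-in activations and a handful of alpha values):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(1024, 512)                      # stand-in activations: (tokens, dim)
    ln = F.layer_norm(x, (512,))                    # LayerNorm with no weight/bias
    for alpha in (0.2, 0.5, 1.0, 2.0):
        dyt = torch.tanh(alpha * x)                 # DyT before its learned scale
        cos = F.cosine_similarity(ln.flatten(), dyt.flatten(), dim=0)
        print(f"alpha={alpha}: cosine similarity to affine-free LayerNorm = {cos.item():.3f}")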
By incorporating DyT, Transformers without normalization can match or exceed the performance of their normalized counterparts, mostly without hyperparameter tuning.
I suppose normalization kernels have reductions in them, but how hard are reductions in 2025?
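For what it's worth, a rough micro-benchmark sketch (mine; numbers depend heavily on hardware, dtype, and whether the norm is fused into neighboring kernels) comparing a reduction-based RMSNorm against the purely elementwise tanh(alpha * x):

    import time
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(4096, 4096, device=device)
    alpha = torch.tensor(0.5, device=device)

    def rmsnorm(x):
        # needs a reduction over the feature dimension
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)

    def dyt(x):
        # purely elementwise
        return torch.tanh(alpha * x)

    for name, fn in [("rmsnorm", rmsnorm), ("dyt", dyt)]:
        for _ in range(10):                        # warmup
            fn(x)
        if device == "cuda":
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(50):
            fn(x)
        if device == "cuda":
            torch.cuda.synchronize()
        print(f"{name}: {(time.perf_counter() - t0) / 50 * 1e3:.3f} ms/iter")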