CS336 Assignment 1
Implemented RoPE. This was pretty tricky. Initially I tried implementing it with a lot of einsum, and it got confusing to get the rotation to apply correctly. I eventually got it working, but that version needed batched 2x2 matmuls to rotate each pair of components. After that, I rewrote it to apply the rotation formula directly to each component of the (x, y) subvector, which was simpler and more efficient. I also stopped storing every 2x2 matrix and kept only the sin and cos values, since the rotation only needs cos, sin, and the negative of sin.
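For reference, here is a minimal sketch of the per-component formula, written with plain Python lists so it runs without PyTorch. The function name rope_rotate, the base theta = 10000, and the choice to pair adjacent components (x[2k], x[2k+1]) are my assumptions for illustration; some implementations pair the first and second halves of the vector instead.

```python
import math

def rope_rotate(x, position, theta=10000.0):
    """Rotate each adjacent pair (x[2k], x[2k+1]) by a position-dependent angle.

    Hypothetical helper: operates on one vector given as a list of floats.
    """
    d = len(x)
    assert d % 2 == 0, "RoPE needs an even dimension"
    out = [0.0] * d
    for k in range(d // 2):
        # Pair k rotates by position / theta^(2k/d); later pairs rotate more slowly.
        angle = position / (theta ** (2 * k / d))
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[2 * k], x[2 * k + 1]
        # Same result as [[c, -s], [s, c]] @ [a, b], with no 2x2 matrix built.
        out[2 * k] = a * c - b * s
        out[2 * k + 1] = a * s + b * c
    return out
```

In a tensor version the same two lines would apply elementwise to the even and odd slices, with the cos and sin tables precomputed once per position.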
I also implemented softmax, which was pretty straightforward. Scaled dot product attention was also easy, though it was easier for me to write it with matmul and transpose than with einsum. I think the RoPE shenanigans shook my einsum fluency and I kept tripping over it.
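As a reminder of what these two pieces compute, here is a small sketch in plain Python rather than tensors. The names softmax and attend_one_query are hypothetical, it handles a single query with no masking, and subtracting the max before exponentiating is the usual stability trick rather than something specific to my implementation.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)  # shifting by the max leaves the result unchanged but avoids overflow
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attend_one_query(q, keys, values):
    """Scaled dot product attention for one query vector (no mask)."""
    d_k = len(q)
    # Score each key by its dot product with the query, scaled by sqrt(d_k).
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
    weights = softmax(scores)
    # Output is the attention-weighted sum of the value vectors.
    d_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(d_v)]
```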
Implemented RMSNorm, which was pretty straightforward. Implemented SiLU, GLU, and SwiGLU. I think I did it right? It passed the tests, but I’m not sure if I was supposed to compose SwiGLU from GLU more directly than I did.
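One way the composition could look, sketched with plain lists: treat GLU as a generic gate with a pluggable activation, then get SwiGLU by passing SiLU as that activation and adding the output projection. This is only an illustration of the relationship, not a claim about what the assignment expects. The helper names and the W1/W3/W2 naming are assumptions, and bias terms are omitted.

```python
import math

def silu(z):
    """SiLU (swish): z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))

def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def glu(x, W_gate, W_up, activation):
    """Generic gated linear unit: activation(W_gate x) * (W_up x), elementwise."""
    gate = [activation(g) for g in matvec(W_gate, x)]
    up = matvec(W_up, x)
    return [g * u for g, u in zip(gate, up)]

def swiglu_ffn(x, W1, W3, W2):
    """SwiGLU feed-forward: W2 (SiLU(W1 x) * W3 x), i.e. a GLU gated by SiLU."""
    return matvec(W2, glu(x, W1, W3, silu))
```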
Didn’t get to spend too much time on this today. Completed the Linear and Embedding module implementations. The Linear module gave me more reps with einsum. The Embedding module was easy: it comes down to indexing one tensor with a tensor of integers to pull out the corresponding rows.
I completed the BPE tokenizer part of assignment 1. It was good to get the Python flowing again. The trickiest part was understanding the precise definition of the tokenization algorithm. One thing I got a little confused about was whether merges are applied sequentially within a pretoken. Another tricky part was that chr doesn’t give you a byte at all: it returns a one-character string, and for values of 128 and above, encoding that string as UTF-8 produces a multi-byte sequence rather than the single byte I wanted. And bytes(number) yields a byte string of that many zero bytes. It took me a minute to remember to do bytes([number]) to get the single byte.
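Both points are easy to see in a few lines. The first part shows the three ways of turning an integer into bytes. The second is a sketch of encoding under the reading I settled on: within one pretoken, apply each merge in the order it was learned, sweeping left to right. The function apply_merges and the toy merge list are made up for illustration, and this is the slow, obvious version.

```python
# Turning an integer >= 128 into a single byte.
n = 200
as_str = chr(n)                # 'È': a str holding code point U+00C8, not a byte
utf8 = as_str.encode("utf-8")  # b'\xc3\x88': two bytes once encoded
zeros = bytes(n)               # 200 zero bytes, not the byte with value 200
single = bytes([n])            # b'\xc8': the single byte

def apply_merges(pretoken, merges):
    """Encode one pretoken (bytes) by applying merges in learned order."""
    tokens = [bytes([b]) for b in pretoken]  # start from single-byte tokens
    for left, right in merges:
        merged, i = [], 0
        while i < len(tokens):
            # Merge every adjacent (left, right) occurrence in one left-to-right sweep.
            if i + 1 < len(tokens) and tokens[i] == left and tokens[i + 1] == right:
                merged.append(left + right)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens

# Toy merge list, in the order it would have been learned.
toy_merges = [(b"l", b"l"), (b"h", b"e"), (b"he", b"ll")]
encoded = apply_merges(b"hello", toy_merges)  # [b'hell', b'o']
```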
It was useful to implement slow versions of training and encoding first, identify the bottlenecks, and then optimize the implementations for roughly a 100x speedup. The final implementations were quite a bit faster than naive and far faster than the problem required, but they still have a lot of room for improvement, by cutting unnecessary memory allocations and by moving off Python. I don’t think chasing the most optimal versions would be the best use of my time, though.
I was surprised that the pretokenization regex bakes in some big priors about the dataset: it splits on punctuation, whitespace, etc. I expected modern BPE to be more general and less specific to text, but it turns out not.
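To make those priors concrete, here is a simplified ASCII-only approximation of a GPT-2 style pretokenization pattern using the standard library re module. The real pattern uses Unicode classes like \p{L} and \p{N}, which need the third-party regex package, so treat this as a stand-in rather than the exact regex. The name pretokenize is hypothetical.

```python
import re

# ASCII-only approximation: English contractions, optional leading space plus
# letters, digits, or punctuation runs, then whitespace handling.
PRETOKEN_PATTERN = re.compile(
    r"'(?:[sdmt]|ll|ve|re)"   # contractions such as 's, 't, 'll
    r"| ?[A-Za-z]+"           # a word, keeping one leading space
    r"| ?[0-9]+"              # a run of digits
    r"| ?[^\sA-Za-z0-9]+"     # a run of punctuation or symbols
    r"|\s+(?!\S)"             # whitespace not directly followed by a token
    r"|\s+"                   # any remaining whitespace
)

def pretokenize(text):
    """Split text into pretokens; merges are never applied across them."""
    return PRETOKEN_PATTERN.findall(text)

pieces = pretokenize("Hello, world! It's 2024.")
# ['Hello', ',', ' world', '!', ' It', "'s", ' 2024', '.']
```

Every branch encodes an assumption about English-like text: words are separated by spaces, contractions are worth splitting off, and digits and punctuation form their own units.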
Next up will be implementing the transformer language model components and the surrounding training pieces. I’m excited to implement these with the “from scratch” approach in PyTorch, using only parameters and modules. I imagine it being a little less tedious than the BPE training, but we’ll see.