Getting into the code, finetuning and performance.
LLaMA 2.0
Nearly all causal language models adopt the decoder-only variant of the transformer architecture. However, LLaMA 2 introduces a few crucial modifications to this architecture that are optimized for faster inference. Some of the key changes include:
Let’s now get hands-on. The following sections examine these refinements one at a time: for each component we cover its theoretical foundation and practical implications, and then implement it from scratch. The goal is not only to reproduce these mechanisms but to understand the principles behind them, and with that, the design choices in LLaMA 2.
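Traditional Layer Normalization (LN) re-centers and re-scales the inputs of a layer using their mean and variance, which improves training speed and stability across many tasks. However, computing both statistics adds overhead, and that cost becomes significant in deep networks. RMSNorm (Root Mean Square Layer Normalization) simplifies LN by dropping the re-centering step: it divides the inputs by their root mean square alone, based on the hypothesis that LN's re-centering invariance is dispensable and that re-scaling accounts for most of its benefit.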
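Written out for an input vector a of length n with a learned per-feature gain g, the two normalizations differ in how many statistics they need. Here ε is the small stability constant, placed inside the square root the way the code in this section does it, and LayerNorm's optional bias term is left out:

```latex
% LayerNorm: two statistics per input vector a (mean and variance)
\mu = \frac{1}{n}\sum_{i=1}^{n} a_i, \qquad
\sigma^2 = \frac{1}{n}\sum_{i=1}^{n} (a_i - \mu)^2, \qquad
\bar{a}_i = \frac{a_i - \mu}{\sqrt{\sigma^2 + \epsilon}}\, g_i

% RMSNorm: a single statistic (root mean square), no re-centering
\mathrm{RMS}(a) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} a_i^2 + \epsilon}, \qquad
\bar{a}_i = \frac{a_i}{\mathrm{RMS}(a)}\, g_i
```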
This simplification reduces computational overhead and may also improve training dynamics. Experiments in the original paper show that, across different models, RMSNorm performs comparably to LayerNorm while reducing running time by 7% to 64%.
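A minimal sketch makes the difference concrete. It assumes PyTorch, uses two hypothetical helpers (`layer_norm` and `rms_norm`, both without a learned gain), and checks two things: the hand-written LayerNorm matches PyTorch's built-in `F.layer_norm`, and on an input that already has zero mean, LayerNorm and RMSNorm give the same result. That second check shows RMSNorm is effectively LayerNorm without the re-centering step.

```python
import torch
import torch.nn.functional as F


def layer_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Two statistics per vector: mean (re-centering) and variance (re-scaling).
    mean = x.mean(dim=-1, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # One statistic per vector: the mean of squares (no re-centering).
    return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


torch.manual_seed(0)
x = torch.randn(4, 8)

# The hand-written LayerNorm agrees with PyTorch's built-in (no affine parameters).
print(torch.allclose(layer_norm(x), F.layer_norm(x, (8,), eps=1e-6), atol=1e-5))

# On a zero-mean input, the variance equals the mean of squares,
# so the two normalizations coincide.
x_centered = x - x.mean(dim=-1, keepdim=True)
print(torch.allclose(layer_norm(x_centered), rms_norm(x_centered), atol=1e-5))
```

On a general input whose mean is not zero, the two outputs differ. RMSNorm saves work precisely by skipping the mean computation and subtraction.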
LLaMA also changes where normalization is applied. The original Transformer normalizes the output of each sub-layer (post-normalization). LLaMA instead applies RMSNorm to the input of each transformer sub-layer, before self-attention and before the feed-forward network. This is known as pre-normalization, and it improves training stability. The placement, depicted in Figure 2, keeps the activations entering each sub-layer at a consistent scale, which makes deep stacks of blocks easier to train.
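To make the pre-normalization placement concrete, here is a minimal PyTorch sketch of one block. The class names (`SimpleRMSNorm`, `PreNormBlock`) are hypothetical. PyTorch's `nn.MultiheadAttention` and a plain SiLU MLP serve as stand-ins, and the causal mask, rotary embeddings, and other LLaMA specifics are omitted. The only point is where the normalization sits relative to the residual connection.

```python
import torch
from torch import nn


class SimpleRMSNorm(nn.Module):
    """Compact RMSNorm (x divided by its RMS, times a learned gain) so the sketch runs on its own."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class PreNormBlock(nn.Module):
    """Hypothetical pre-norm block: x + sublayer(norm(x)) for attention, then for the MLP.

    A post-norm block (original Transformer) would compute norm(x + sublayer(x)) instead.
    """

    def __init__(self, dim: int, n_heads: int, hidden_dim: int):
        super().__init__()
        self.attention_norm = SimpleRMSNorm(dim)
        self.attention = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.ffn_norm = SimpleRMSNorm(dim)
        # Stand-in MLP; LLaMA's feed-forward layer uses SwiGLU instead.
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize the sub-layer input; the residual stream itself is left untouched.
        h = self.attention_norm(x)
        attn_out, _ = self.attention(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.feed_forward(self.ffn_norm(x))
        return x


torch.manual_seed(0)
block = PreNormBlock(dim=16, n_heads=4, hidden_dim=64)
x = torch.randn(2, 5, 16)  # (batch, seq_len, dim)
print(block(x).shape)  # torch.Size([2, 5, 16])
```

Because the residual path is never normalized, each block adds a correction to an un-normalized stream. This is the property usually credited for the better stability of pre-norm stacks.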
The code below divides the input x by its root mean square, which makes the result invariant to the scale of the input, and then multiplies each normalized value by a learned per-feature gain, self.weight.
# Root Mean Square Layer Normalization (https://arxiv.org/abs/1910.07467)
# borrowed from the official Llama implementation:
# https://github.com/facebookresearch/llama/blob/main/llama/model.py
import torch
from torch import nn


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Compute the root mean square normalization over the last dimension
        (Equation 4 in Section 4 of https://arxiv.org/abs/1910.07467).
        The epsilon value (self.eps) is added to the mean of squares, i.e.
        inside the square root, before the tensor is normalized.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        # Mean of squares over the feature (last) dimension: one value per vector.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        # Multiply by 1 / sqrt(mean_sq + eps), i.e. divide by the RMS.
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
        """
        Apply the root mean square normalizer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        # Normalize in float32 for numerical stability, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        # Scale each feature by its learned gain.
        return output * self.weight
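The scale-invariance property described above can be checked numerically. This sketch uses a hypothetical functional helper, `rms_normalize`, that computes x / sqrt(mean(x²) + eps) over the last dimension, as the official Llama implementation does. It assumes inputs shaped (batch, seq_len, dim), so the statistics are taken per token rather than across the sequence.

```python
import torch


def rms_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Statistics over the last (feature) dimension: one RMS value per token.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


torch.manual_seed(0)
x = torch.randn(2, 5, 16)  # (batch, seq_len, dim)
y = rms_normalize(x)

# Every token vector now has a root mean square of (almost exactly) 1.
token_rms = y.pow(2).mean(dim=-1).sqrt()
print(torch.allclose(token_rms, torch.ones(2, 5), atol=1e-4))

# Scale invariance: multiplying the input by a positive constant leaves the output unchanged.
print(torch.allclose(rms_normalize(10.0 * x), y, atol=1e-4))
```

Reducing over dim=-1 rather than a fixed axis such as 1 matters here. For a 3-D input, axis 1 is the sequence dimension, so normalizing over it would mix information across tokens.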