Building Llama from scratch.

Getting into the code, finetuning and performance.

Overview

LLaMA 2.0, Meta AI's family of open-source pre-trained language models, has quickly become one of the strongest openly available foundation models. What sets LLaMA apart is its ability to match or outperform the much larger GPT-3 (175B) on many benchmarks while using only a fraction of the parameters (7B to 70B). This has drawn the attention of the global tech community and led to a wave of open-source models in the past few months. In this blog, we will dig into the model's architecture by implementing some of the most important components of LLaMA 2.0 from scratch. The goal of this hands-on approach is to dissect the architecture and the design choices behind LLaMA, and to understand what makes it perform so well.

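Nearly all causal language models adopt the decoder-only variant of the transformer architecture. However, LLaMA 2 introduces a few crucial modifications to this architecture that improve training stability and inference speed. Some of the key changes, summarized in Figure 1, are pre-normalization with RMSNorm, rotary positional embeddings (RoPE), the SwiGLU activation in the feed-forward layers, and, in the larger LLaMA 2 models, grouped-query attention, which shrinks the key-value (KV) cache used during inference.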

Figure 1: Architecture overview of LLaMA (Source: Umar Jamil)

Let’s now embark on a more hands-on journey. The upcoming sections are dedicated to unraveling the intricacies of these refinements. We’ll delve into each component, understanding its theoretical foundation and practical implications. Then, armed with this knowledge, we’ll attempt to code these features from scratch. This exercise aims not just to replicate these sophisticated mechanisms but to decode their underlying principles, enabling us to appreciate the craftsmanship behind LLaMA 2’s design.

RMS normalization

Figure 2: LLaMA-2 uses a pre-normalization variant of the normal transformer block. (Source: Deep(Learning) Focus)

Traditional Layer Normalization (LN) re-centers and re-scales each layer's inputs using their mean and variance, which improves training speed and stability across many tasks. However, computing both statistics adds overhead that becomes noticeable in deep networks, where LN is applied many times per forward pass. RMSNorm drops the re-centering step and re-scales activations by their root mean square (RMS) alone, so the mean never has to be computed. For an input vector a with n elements and a learned gain g:

\[ \bar{a}_i = \frac{a_i}{\text{RMS}(a)}\, g_i, \quad \text{where} \quad \text{RMS}(a) = \sqrt{\frac{1}{n} \sum_{j=1}^{n} a_j^2}. \]
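As a quick worked example of the formula (the values are chosen purely for illustration), take n = 2, a = (3, 4), and unit gains g = (1, 1). Doubling the input to (6, 8) doubles the RMS and leaves the normalized output unchanged, which is the re-scaling invariance RMSNorm is built around:

\[
\begin{aligned}
\text{RMS}(a) &= \sqrt{\tfrac{1}{2}\left(3^2 + 4^2\right)} = \sqrt{12.5} \approx 3.536, &
\bar{a} &\approx \left(\tfrac{3}{3.536},\ \tfrac{4}{3.536}\right) \approx (0.849,\ 1.131),\\
\text{RMS}(2a) &= \sqrt{\tfrac{1}{2}\left(6^2 + 8^2\right)} = \sqrt{50} \approx 7.071, &
\overline{2a} &\approx \left(\tfrac{6}{7.071},\ \tfrac{8}{7.071}\right) \approx (0.849,\ 1.131).
\end{aligned}
\]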

Dropping the mean computation reduces the cost of every normalization layer. The original paper hypothesizes that LN's re-scaling invariance, rather than its re-centering, is what makes it effective, and its experiments show that, across different models, RMSNorm achieves performance comparable to LayerNorm while running 7% to 64% faster.
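To make the difference concrete, here is a small plain-Python sketch that applies both normalizations to one vector. The helper names `layer_norm` and `rms_norm` are illustrative, not from any library, and both omit the learned gain (and LayerNorm's bias) so only the statistics differ: LayerNorm needs the mean before it can compute the variance, while RMSNorm needs a single sum of squares.

```python
import math
from typing import List


def layer_norm(x: List[float], eps: float = 1e-6) -> List[float]:
    """LayerNorm without gain/bias: re-center by the mean, then re-scale by the std."""
    n = len(x)
    mean = sum(x) / n                               # first pass: the mean
    var = sum((v - mean) ** 2 for v in x) / n       # second pass: variance around the mean
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def rms_norm(x: List[float], eps: float = 1e-6) -> List[float]:
    """RMSNorm without gain: re-scale by the root mean square; no mean is computed."""
    n = len(x)
    rms = math.sqrt(sum(v * v for v in x) / n + eps)  # single pass: mean of squares
    return [v / rms for v in x]


x = [1.0, 2.0, 3.0, 4.0]
print([round(v, 3) for v in layer_norm(x)])
print([round(v, 3) for v in rms_norm(x)])
```

The LayerNorm output is centered at zero, while the RMSNorm output keeps the input's sign pattern (here, all positive); both end up with a root mean square of about 1.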

LLaMA also changes where the normalization happens. Instead of normalizing the output of each transformer sub-layer, as the original Transformer does (post-normalization), it applies RMSNorm to the input of each sub-layer, that is, before the self-attention and feed-forward blocks. This placement, known as pre-normalization and depicted in Figure 2, leaves the residual path untouched and improves training stability, especially in deep models.
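The difference between the two placements is easiest to see side by side. The plain-Python sketch below is a toy under stated assumptions: `toy_attention` and `toy_feed_forward` are hypothetical stand-ins for the real sub-layers, the learned RMSNorm gains are omitted, and the post-norm variant also uses RMSNorm (the original Transformer used LayerNorm) so that only the placement changes.

```python
import math
from typing import Callable, List

Vector = List[float]
SubLayer = Callable[[Vector], Vector]


def rms_norm(x: Vector, eps: float = 1e-6) -> Vector:
    """RMSNorm without the learned gain, for brevity."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def add(a: Vector, b: Vector) -> Vector:
    """Elementwise sum, i.e. the residual connection."""
    return [u + v for u, v in zip(a, b)]


def pre_norm_block(x: Vector, attention: SubLayer, feed_forward: SubLayer) -> Vector:
    # LLaMA-style: normalize the *input* of each sub-layer; the residual
    # stream itself (x, then h) is never normalized.
    h = add(x, attention(rms_norm(x)))
    return add(h, feed_forward(rms_norm(h)))


def post_norm_block(x: Vector, attention: SubLayer, feed_forward: SubLayer) -> Vector:
    # Original-Transformer placement: normalize *after* each residual addition.
    h = rms_norm(add(x, attention(x)))
    return rms_norm(add(h, feed_forward(h)))


# Hypothetical toy sub-layers standing in for self-attention and the MLP.
toy_attention: SubLayer = lambda v: [0.5 * t for t in v]
toy_feed_forward: SubLayer = lambda v: [t * t for t in v]

x = [1.0, -2.0, 3.0, 0.5]
print(pre_norm_block(x, toy_attention, toy_feed_forward))
print(post_norm_block(x, toy_attention, toy_feed_forward))
```

If both sub-layers output zeros, `pre_norm_block` returns `x` unchanged (an identity path through the block), whereas `post_norm_block` still rescales it. That clean residual path is one commonly cited reason pre-normalization trains more stably in deep stacks.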

The code below normalizes the input x by its root mean square over the last dimension, which makes the output (up to eps) invariant to rescaling x, and then multiplies each normalized value by a learned per-dimension gain, self.weight (g_i in the equation above).


import torch
from torch import nn

# Root Mean Square Layer Normalization (https://arxiv.org/abs/1910.07467)
# adapted from the official Llama implementation:
# https://github.com/facebookresearch/llama/blob/main/llama/model.py
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.
        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Compute the root mean square normalization. Use Equation 4 under
        Section 4 of https://arxiv.org/abs/1910.07467 as a reference. Add 
        the given epsilon value (self.eps) to the tensor's norm (i.e. inside
        the square root in Equation 4) before normalizing the tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
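        # Mean of the squared activations over the last (feature) dimension.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        # RMS with eps added inside the square root for numerical stability.
        rms_x = torch.sqrt(mean_sq + self.eps)
        x_normed = x / rms_x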

        return x_normed

    def forward(self, x):
        """
        Apply the root mean square normalizer.
        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
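As a sanity check on the class above, here is a dependency-free Python reference for what `RMSNorm.forward` should compute on a `(batch, dim)` input: each row is normalized over its last dimension and then scaled elementwise by `weight`. The helper `rms_norm_rows` is hypothetical and exists only for this comparison.

```python
import math
from typing import List

Matrix = List[List[float]]


def rms_norm_rows(x: Matrix, weight: List[float], eps: float = 1e-6) -> Matrix:
    """Reference RMSNorm: normalize each row over its last dimension, then apply the gain."""
    out = []
    for row in x:
        mean_sq = sum(v * v for v in row) / len(row)   # mean of squares for this row only
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)       # eps inside the square root
        out.append([v * inv_rms * w for v, w in zip(row, weight)])
    return out


x = [[1.0, 2.0, 3.0, 4.0], [-2.0, 0.5, 0.0, 1.5]]
weight = [1.0, 1.0, 1.0, 1.0]  # matches the class's torch.ones(dim) initialization
y = rms_norm_rows(x, weight)

# With unit weights, every output row should have an RMS very close to 1.
for row in y:
    print(math.sqrt(sum(v * v for v in row) / len(row)))

# Scaling the input by a positive constant should leave the output (almost) unchanged.
y_scaled = rms_norm_rows([[10.0 * v for v in row] for row in x], weight)
```

If PyTorch is installed, `RMSNorm(4)(torch.tensor(x))` should agree with `rms_norm_rows(x, [1.0] * 4)` up to float32 tolerance, because `weight` starts as all ones. Normalizing over the last dimension means each token's feature vector is handled independently of any batch or sequence dimensions in front of it.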