Stanford CS336 Hands-on LLMs, Assignment 1: BPE Tokenizer & Transformer LM

All code is kept up to date at https://github.com/WangYuHang-cmd/CS336/tree/main/assignment1-basics

Assignment file structure:

CS336/assignment1-basics/
├── tests/                    # test directory
│   ├── adapters.py          # test adapters
│   ├── conftest.py          # pytest configuration
│   ├── __init__.py          # package init
│   ├── snapshots/           # test snapshots
│   ├── test_data.py         # data-handling tests
│   ├── test_model.py        # model tests
│   ├── test_nn_utils.py     # NN utility tests
│   ├── test_optimizer.py    # optimizer tests
│   ├── test_serialization.py # serialization tests
│   ├── test_tokenizer.py    # tokenizer tests
│   └── test_train_bpe.py    # BPE training tests
│
└── cs336_basics/            # implementation directory
    ├── attention.py         # attention implementation
    ├── embedding.py         # embedding layer
    ├── linear.py           # linear layer
    ├── optimizer.py        # optimizer
    ├── tokenizer.py        # tokenizer
    ├── transformerLM.py    # Transformer language model
    ├── rope.py             # RoPE positional embedding
    ├── rmsnorm.py          # RMSNorm layer
    ├── softmax.py          # softmax implementation
    ├── swiglu.py           # SwiGLU activation
    ├── utils.py            # utility functions
    └── debug_*.py          # debug scripts

BPE Tokenizer

BPE Class

First, the BPE class. We need to correctly implement the interface the assignment has already defined:

from typing import Dict, List, Tuple


class BPETokenizer:
    def __init__(self, vocab_size: int, special_tokens: list[str] | None = None):
        self.vocab_size = vocab_size
        self.special_tokens = special_tokens or []
        self.special_tokens_bytes = [
            token.encode("utf-8") for token in self.special_tokens
        ]

        self.merges: List[Tuple[bytes, bytes]] = []
        self.stoi: Dict[bytes, int] = {}
        self.itos: Dict[int, bytes] = {}
        self.merges_rank: Dict[Tuple[bytes, bytes], int] = {}

        # init vocab
        for i, token_bytes in enumerate(self.special_tokens_bytes):  # special tokens
            self.stoi[token_bytes] = i
            self.itos[i] = token_bytes

        offset = len(self.special_tokens_bytes)
        for i in range(256):
            self.stoi[bytes([i])] = i + offset
            self.itos[i + offset] = bytes([i])

        self.vocab = self.itos.copy()  # for serialization
        self.merges_rank = {}  # for fast lookup
        # pair2new: (p1, p2) -> new_token_id (empty here, since merges is empty; rebuilt after training)
        self.pair2new = {(p1, p2): self.stoi[p1 + p2] for (p1, p2) in self.merges}

Here stoi maps each token (as bytes) to its token id, and itos maps each token id back to its bytes. During initialization we first register all special tokens and then the 256 single-byte values 0-255.
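
As a quick sanity check of this layout, here is a minimal sketch (the vocab_size value is arbitrary) verifying that the special token gets id 0 and that single bytes follow with an offset:

tok = BPETokenizer(vocab_size=500, special_tokens=["<|endoftext|>"])
assert tok.stoi[b"<|endoftext|>"] == 0        # special tokens are registered first
assert tok.stoi[bytes([65])] == 65 + 1        # byte 'A' is shifted by len(special_tokens)
assert len(tok.stoi) == 1 + 256               # specials + all 256 single-byte tokens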

BPE Training

A BPE tokenizer is learned from data. It works at the byte level, and the vocabulary it ends up learning contains whole words, subword roots, and all kinds of other fragments.

The core of BPE training: first pre-tokenize the text to get a list of pre-tokens, so the whole text is split into many chunks. Special tokens are split out separately (they never take part in merges), so the one big list is cut into several smaller lists at special-token boundaries. We then count every adjacent token pair inside each small list and repeatedly merge according to the following rules (see the small illustration after the list):

1. Find the pair <token1, token2> with the highest count; several pairs may be tied.
2. Among ties, prefer the pair whose token1 is lexicographically greater.
3. If still tied, prefer the pair whose token2 is lexicographically greater.
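
A tiny illustration of the tie-breaking rule, with made-up byte pairs and counts:

pair_counts = {(b"a", b"b"): 3, (b"e", b"st"): 3, (b"x", b"y"): 2}
best = max(pair_counts, key=lambda p: (pair_counts[p], p[0], p[1]))
print(best)  # (b'e', b'st'): tied with (b'a', b'b') on count, but b'e' > b'a'
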
Pre-tokenize

The pretokenize function splits the text into a list of well-formed word chunks, for example:

import regex as re  # the GPT-2 pattern uses \p{L}/\p{N}, which the stdlib re module does not support

GPT2_SPLIT_PATTERN = (
    r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
def pretokenize(text: str) -> list[bytes]:
    str_tokens = re.findall(GPT2_SPLIT_PATTERN, text)
    byte_tokens = [s.encode("utf-8") for s in str_tokens]
    return byte_tokens

For example, "Hello world, this is user-123!" is pre-tokenized into ['Hello', ' world', ',', ' this', ' is', ' user', '-', '123', '!'] (each chunk then encoded as UTF-8 bytes).

Train

We can easily write a brute-force training method (see the slow_train function in the code), in which we do roughly the following:

num_merges_needed = self.vocab_size - len(self.stoi)  # number of merges needed; each merge grows the vocab by one

for merge_cnt in range(num_merges_needed):
    pair_counts = self._get_stats(token_groups)  # scan the token groups and count every adjacent token-id pair
    best_pair = max( 
        pair_counts,
        key=lambda p: (pair_counts[p], self.itos[p[0]], self.itos[p[1]]),
    )  # pick the pair to merge according to the rules above
    # update all the lookup tables after the merge
    new_token_id = len(self.itos)
    p1_bytes, p2_bytes = self.itos[best_pair[0]], self.itos[best_pair[1]]
    new_token_bytes = p1_bytes + p2_bytes
    self.merges.append((p1_bytes, p2_bytes))
    self.stoi[new_token_bytes] = new_token_id
    self.itos[new_token_id] = new_token_bytes
    ...

However, this not only risks failing tests/test_train_bpe.py::test_train_bpe_speed (my brute-force version took about 5.8 s, far above the 1.5 s limit); tests/test_train_bpe.py::test_train_bpe_special_tokens also took nearly 7 minutes:

==================================================== 1 failed, 2 passed in 476.21s (0:07:56) ====================================================

So we need to optimize the merge loop. The time is dominated by two things: (1) recounting all pairs from scratch on every merge, and (2) rewriting the whole token-id sequence after every merge. Both can be fixed with better data structures. Store the token-id sequence in a doubly linked list, and for each pair keep a set of the linked-list positions where it occurs. Then a single pass over the token-id sequence is enough to count every pair and push them all into a heap, and the heap pops the next pair to merge from the top. Crucially, we never modify entries inside the heap after a merge: when a pair is popped we simply check whether it still exists and whether its count matches pair_counts, and skip stale entries. After each merge we only re-push the pairs whose counts actually changed. Since each merge only touches a few neighbours, the total complexity is roughly O(n log n).
To summarize, the plan is as follows.

Data structures:

  1. A max-heap over token-id pairs, ordered by BPE merge priority.
  2. A doubly linked list holding the current token-id sequence.
  3. For each pair, a set of the linked-list node positions where that pair occurs.

Update procedure

1. Pop a token-id pair from the top of the heap and check whether its count matches the one
   recorded in pair_counts; if not, pop the next one until they agree. That pair is the next BPE merge.
2. Through the positions recorded for the pair, find pos_idx and nxt_idx for each occurrence,
   plus the neighbours pre_idx and nnxt_idx:
    pre_idx <-> pos_idx <-> nxt_idx <-> nnxt_idx   becomes, after the merge,
    pre_idx <-> (pos_idx, nxt_idx) <-> nnxt_idx
    new_token = token[pos_idx] + token[nxt_idx]
    
3. Update pair_counts: walk all recorded positions of the pair, keep only those where
   nxt[] still points at token[nxt_idx], and apply
    pair_counts[(token[pre_idx], token[pos_idx])] -= 1
    pair_counts[(token[pre_idx], new_token)] += 1
    pair_counts[(token[nxt_idx], token[nnxt_idx])] -= 1
    pair_counts[(new_token, token[nnxt_idx])] += 1

    pos[new_token].add(pos_idx)
    pre[nnxt_idx] = pos_idx
    nxt[pos_idx] = nnxt_idx
    pre[nxt_idx] = nxt[nxt_idx] = None  # detach the second token of the merged pair from the list

Because Python's heapq is a min-heap by default, we wrap pairs in a class whose comparison is inverted so the heap behaves like a max-heap:

class PairItem:
    def __init__(self, count, token_id1, token_id2, itos):
        self.count = count
        self.token_id1 = token_id1
        self.token_id2 = token_id2
        self.itos = itos
        self.bytes1 = itos[token_id1]
        self.bytes2 = itos[token_id2]
    
    def __lt__(self, other):
        # higher count first
        if self.count != other.count:
            return self.count > other.count
        # on a count tie, compare the first token's bytes in descending order
        if self.bytes1 != other.bytes1:
            return self.bytes1 > other.bytes1
        # if the first tokens match, compare the second token's bytes in descending order
        return self.bytes2 > other.bytes2
    
    def __eq__(self, other):
        return (self.count == other.count and 
                self.bytes1 == other.bytes1 and 
                self.bytes2 == other.bytes2)
    
    def get_pair(self):
        return (self.token_id1, self.token_id2)
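
A minimal usage sketch of PairItem with heapq (hypothetical one-byte vocabulary): because __lt__ is inverted, the min-heap pops the highest-count pair first and breaks ties by the lexicographically greatest first token:

import heapq

itos = {0: b"a", 1: b"b", 2: b"c"}
heap = [PairItem(2, 0, 1, itos), PairItem(5, 1, 2, itos), PairItem(5, 0, 2, itos)]
heapq.heapify(heap)
top = heapq.heappop(heap)
print(top.count, top.bytes1, top.bytes2)  # 5 b'b' b'c'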

We then read the text and run it through pre-tokenization:

# Pre-Tokenizer
assert self.vocab_size >= len(self.stoi)

with open(path, "r", encoding="utf-8") as f:
    text = f.read()

if self.special_tokens:  # Special Token
    special_pattern = f"({'|'.join(re.escape(s) for s in self.special_tokens)})"
    text_parts = re.split(special_pattern, text)
else:
    text_parts = [text]

# Pre-Tokenizer
initial_vocab_map = {v: k for k, v in self.itos.items()}
token_groups = []
for part in text_parts:
    if part in self.special_tokens or not part:
        continue
    words_in_bytes = pretokenize(part)
    for word in words_in_bytes:
        token_groups.append([initial_vocab_map[bytes([b])] for b in word])

A single pass over the whole token-id sequence is enough to build the statistics:

# BPE Merge
idx = 0
pair_counts = {}
token = {}
pre = {}
nxt = {}
pos = {}

for i, token_lst in enumerate(token_groups):
    if not token_lst or len(token_lst) <= 1:
        continue
    token_lst_len = len(token_lst)
    for j, token_id in enumerate(token_lst):
        idx += 1
        token[idx] = token_id
        nxt[idx] = None if j == token_lst_len - 1 else idx + 1
        pre[idx] = None if j == 0 else idx - 1
        if j == token_lst_len - 1:
            continue
        token_pair = (token_id, token_lst[j + 1])
        pair_counts[token_pair] = pair_counts.get(token_pair, 0) + 1
        if pos.get(token_pair) is None:
            pos[token_pair] = set()
        pos[token_pair].add(idx)

heap = []
for (a, b), cnt in pair_counts.items():
    item = PairItem(cnt, a, b, self.itos)
    heapq.heappush(heap, item)

Now we can start the BPE merges. The order of operations and the details need a lot of care, especially the update order and the check for whether a popped pair is still valid:

def update_pair(pair: tuple[int, int], delta: int, pos_idx: int | None = None):
    if pair is None or None in pair: 
        return
    pair_counts[pair] = pair_counts.get(pair, 0) + delta
    cnt = pair_counts[pair]
    if cnt <= 0:
        pair_counts.pop(pair, None)
        pos.pop(pair, None)
        return
    if pos_idx is not None:
        ds = pos.setdefault(pair, set())
        if delta > 0:
            ds.add(pos_idx)
        elif delta < 0:
            ds.discard(pos_idx)
    a, b = pair
    item = PairItem(cnt, a, b, self.itos)
    heapq.heappush(heap, item)

num_merges_needed = self.vocab_size - len(self.stoi)
while num_merges_needed > 0 and heap:
    if not pair_counts: 
        break
    num_merges_needed -= 1
    
    while heap:
        item = heapq.heappop(heap)
        p1, p2 = item.get_pair()
        
        # check whether this pair is still valid
        if (p1, p2) not in pair_counts or pair_counts[(p1, p2)] != item.count:
            continue  # stale heap entry: this pair has already been changed by an earlier merge

        # merge the new token
        self.merges.append((self.itos[p1], self.itos[p2]))

        p1_bytes, p2_bytes = self.itos[p1], self.itos[p2]
        new_token_bytes = p1_bytes + p2_bytes
        new_token_id = (
            len(self.stoi)
            if self.stoi.get(new_token_bytes) is None
            else self.stoi[new_token_bytes]
        )
        self.stoi[new_token_bytes] = new_token_id
        self.itos[new_token_id] = new_token_bytes

        pos_lst = list(pos.get((p1, p2), set()))
        # modify the token group
        for pos_idx in pos_lst:
            pre_idx = pre[pos_idx]
            nxt_idx = nxt[pos_idx]
            nnxt_idx = nxt[nxt_idx] if nxt_idx is not None else None

            if nxt_idx is None or token[pos_idx] != p1 or token[nxt_idx] != p2: 
                continue

            if pre_idx is not None:
                nxt[pre_idx] = pos_idx  # keep unchanged
                update_pair((token[pre_idx], token[pos_idx]), -1, pre_idx)
                update_pair((token[pre_idx], new_token_id), 1, pre_idx)
            
            if nnxt_idx is not None:
                pre[nnxt_idx] = pos_idx
                update_pair((token[nxt_idx], token[nnxt_idx]), -1, nxt_idx)
                update_pair((new_token_id, token[nnxt_idx]), 1, pos_idx)
            
            pre[pos_idx] = pre_idx
            nxt[pos_idx] = nnxt_idx
            token[pos_idx] = new_token_id
            token[nxt_idx] = None  # remove the old token
            pre[nxt_idx] = None
            nxt[nxt_idx] = None
            
        pair_counts.pop((p1, p2), None)
        pos.pop((p1, p2), None)
        break

self.merges_rank = {pair: i for i, pair in enumerate(self.merges)}
self.vocab = self.itos.copy()
self.pair2new = {(p1, p2): self.stoi[p1 + p2] for (p1, p2) in self.merges}

Re-running the tests, training is now much faster:

============================================================== 3 passed in 30.85s ====================================

For the speed test in particular:

# brute-force timing: ~5.75 s, fails the 1.5 s limit
(1752185555.6502326 - 1752185549.8956482) < 1.5 
# after the optimization
tests/test_train_bpe.py::test_train_bpe_speed time using toy implementation: 0.32 seconds

Of course, besides overriding the ordering inside the heap with a wrapper class, we can also build the byte-string comparison key by hand. The only catch is that the shorter byte string must be padded at the end until it is as long as the longer one so the descending keys compare correctly (you can just pick a reasonably large max_len by hand; this version is also fast):

def bytes_desc(b):
    return bytes(255 - x for x in b)

def pair_desc(pair):
    a = self.itos[pair[0]]
    b = self.itos[pair[1]]
    max_len = 2
    a_pad = a + bytes([0] * (max_len - len(a)))
    b_pad = b + bytes([0] * (max_len - len(b)))
    return (bytes_desc(a_pad), bytes_desc(b_pad))

heap = [
    (
        -cnt,  # negate the count: higher frequency -> smaller value, popped first by the min-heap
        pair_desc((a, b)),
        a, b,
    )  # token-1 id, token-2 id
    for (a, b), cnt in pair_counts.items()
]
heapq.heapify(heap)

BPE Encode & Decode

First, the Encode step: we need to turn an input string into a sequence of integer ids. Three things to be careful about (illustrated by the small sketch below):

1. Handle special tokens first: recognize and protect special tokens such as <|endoftext|>.
2. Sort special tokens by length (longest first), so a short special token is never matched inside a longer one that contains it.
3. Split the text into segments: special-token segments and ordinary-text segments.
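
A small sketch of the splitting step, showing how re.split with a capturing group keeps the special tokens as their own segments (the text here is made up):

import re

specials = ["<|endoftext|>"]
special_pattern = f"({'|'.join(re.escape(s) for s in specials)})"
parts = re.split(special_pattern, "doc one<|endoftext|>doc two")
print(parts)  # ['doc one', '<|endoftext|>', 'doc two']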

Let's first implement the encoder for text that contains no special tokens:

    def _encode_ordinary_text(self, text_bytes: bytes) -> list[int]:
        if not text_bytes:
            return []

        try:
            text = text_bytes.decode("utf-8")
        except UnicodeDecodeError:
            text = text_bytes.decode("utf-8", errors="replace")

        ids_out = array("H")  # uint16 is enough for a vocab <= 65k
        pair_rank = self.merges_rank
        pair2new = self.pair2new
        byte2id = self.stoi  # local alias for speed

        # process word chunks one by one instead of building a single huge list
        for word_b in iter_pretokenize(text):
            token_ids = array("H", (byte2id[bytes([b])] for b in word_b))

            # b. merge in place: greedy smallest-rank merge
            while True:
                best_rank = 1000000000
                best_pos = -1
                # find the pair with the smallest merge rank in the current sequence
                for i in range(len(token_ids) - 1):
                    r = pair_rank.get(
                        (self.itos[token_ids[i]], self.itos[token_ids[i + 1]]),
                        1000000000,
                    )
                    if r < best_rank:
                        best_rank, best_pos = r, i
                if best_pos == -1:
                    break
                
                # replace the tokens at best_pos and best_pos+1 with the merged token
                new_id = pair2new[
                    (self.itos[token_ids[best_pos]], self.itos[token_ids[best_pos + 1]])
                ]
                token_ids[best_pos : best_pos + 2] = array("H", [new_id])

            ids_out.extend(token_ids)

        # array → list
        return ids_out.tolist()

Here I use array (from the standard-library array module) instead of a plain list, so each token id takes only 2 bytes; processing the text chunk by chunk keeps memory from blowing up.
Next, the encoder that handles special tokens:

    def encode(self, text: str) -> list[int]:
        """Encode str"""
        if not text:
            return []

        sorted_special_tokens = sorted(self.special_tokens, key=len, reverse=True)
        if not sorted_special_tokens:
            return self._encode_ordinary_text(text.encode("utf-8"))

        special_pattern = f"({'|'.join(re.escape(s) for s in sorted_special_tokens)})"
        text_parts = re.split(special_pattern, text)

        all_ids = []
        for part in text_parts:
            if part in self.special_tokens:
                all_ids.append(self.stoi[part.encode("utf-8")])
            elif part:
                all_ids.extend(self._encode_ordinary_text(part.encode("utf-8")))
        return all_ids

The decode function is simple: turn a token-id sequence back into a string by looking each id up in the vocabulary built during BPE training, concatenating the bytes, and decoding as UTF-8:

def decode(self, ids: list[int]) -> str:
    all_bytes = b"".join(self.itos.get(id, b"") for id in ids)
    return all_bytes.decode("utf-8", errors="replace")
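
A round-trip usage sketch, assuming tok is a BPETokenizer whose special tokens include <|endoftext|>:

ids = tok.encode("Hello world<|endoftext|>")
assert tok.decode(ids) == "Hello world<|endoftext|>"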

Finally, we give the BPETokenizer class a way to be reconstructed from serialized state:

    @classmethod
    def from_serialized(
        cls,
        vocab: dict[int, bytes],
        merges: list[tuple[bytes, bytes]],
        special_tokens: list[str],
    ):
        instance = cls(vocab_size=len(vocab), special_tokens=special_tokens)
        instance.stoi = {v: k for k, v in vocab.items()}
        instance.itos = vocab
        instance.merges = merges
        instance.merges_rank = {pair: i for i, pair in enumerate(merges)}
        instance.vocab = vocab

        instance.pair2new = {(p1, p2): instance.stoi[p1 + p2] for (p1, p2) in merges}

        return instance
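
A usage sketch (names illustrative): rebuilding a tokenizer from the state of an already-trained one and checking that both encode identically:

tok2 = BPETokenizer.from_serialized(
    vocab=tok.vocab,
    merges=tok.merges,
    special_tokens=["<|endoftext|>"],
)
assert tok2.encode("Hello world") == tok.encode("Hello world")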

Test results (note that the XFAIL on the last test is expected; it means you didn't cheat…):

============================================================== 3 passed in 30.85s ===============================================================
(llm) henry@motif-gpu:~/Desktop/LLM/CS336/assignment1-basics$ python -m pytest -q tests/test_train_bpe.py 

tests/test_train_bpe.py::test_train_bpe_speed time using toy implementation: 0.32 seconds
PASSED
tests/test_train_bpe.py::test_train_bpe PASSED
tests/test_train_bpe.py::test_train_bpe_special_tokens PASSED
(llm) henry@motif-gpu:~/Desktop/LLM/CS336/assignment1-basics$ python -m pytest -q tests/test_tokenizer.py 

tests/test_tokenizer.py::test_roundtrip_empty PASSED
tests/test_tokenizer.py::test_empty_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_single_character PASSED
tests/test_tokenizer.py::test_single_character_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_single_unicode_character PASSED
tests/test_tokenizer.py::test_single_unicode_character_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_ascii_string PASSED
tests/test_tokenizer.py::test_ascii_string_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_unicode_string PASSED
tests/test_tokenizer.py::test_unicode_string_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_unicode_string_with_special_tokens PASSED
tests/test_tokenizer.py::test_unicode_string_with_special_tokens_matches_tiktoken PASSED
tests/test_tokenizer.py::test_overlapping_special_tokens PASSED
tests/test_tokenizer.py::test_address_roundtrip PASSED
tests/test_tokenizer.py::test_address_matches_tiktoken PASSED
tests/test_tokenizer.py::test_german_roundtrip PASSED
tests/test_tokenizer.py::test_german_matches_tiktoken PASSED
tests/test_tokenizer.py::test_tinystories_sample_roundtrip PASSED
tests/test_tokenizer.py::test_tinystories_matches_tiktoken PASSED
tests/test_tokenizer.py::test_encode_special_token_trailing_newlines PASSED
tests/test_tokenizer.py::test_encode_special_token_double_newline_non_whitespace PASSED
tests/test_tokenizer.py::test_encode_iterable_tinystories_sample_roundtrip PASSED
tests/test_tokenizer.py::test_encode_iterable_tinystories_matches_tiktoken PASSED
tests/test_tokenizer.py::test_encode_iterable_memory_usage PASSED
tests/test_tokenizer.py::test_encode_memory_usage XFAIL (Tokenizer.encode is expected to take more memory than allotted (1MB).)

========================================================= 24 passed, 1 xfailed in 4.50s =========================================================

TransformerLM

For the Transformer LM, I found this part fairly routine; just follow the course PDF. It is, however, good practice for einsum, reduce, and rearrange from einops. Below are a few points worth noting.

RoPE

The forward pass can run into precision problems here, so cast to torch.float32 first and cast back to the input dtype at the end:

import torch
from torch import nn, Tensor
from einops import einsum, rearrange
from jaxtyping import Float, Int


class RoPE(nn.Module):
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None, dtype=None):
        super().__init__()
        self.theta = theta
        self.d_k = d_k
        self.max_seq_len = max_seq_len

        self.half_dim = d_k // 2
        freq_seq = torch.arange(self.half_dim, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (theta ** (freq_seq / self.half_dim))
        t = torch.arange(max_seq_len, dtype=torch.float32, device=device)

        freqs = einsum(t, inv_freq, "i, j -> i j")
        cos = torch.cos(freqs)
        sin = torch.sin(freqs)

        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)


    def forward(
        self,
        x: Float[Tensor, "... seq_len d_k"],
        token_positions: Int[Tensor, "... seq_len"],
    ) -> Float[Tensor, "...  seq_len d_k"]:
        
        assert x.shape[-1] == self.d_k, f"x's last dim {x.shape[-1]} != d_k {self.d_k}"
        assert self.d_k % 2 == 0, "d_k must be even for RoPE"
        
        in_type = x.dtype
        x = x.to(torch.float32)
        
        # (... seq_len d_k) ->  (... seq_len d_pair 2) 2D-Tensor
        x_pair = rearrange(x, "... seq_len (d_pair two) -> ... seq_len d_pair two", two = 2)
        
        # cos/sin tensor build
        cos = self.cos_cached[token_positions]
        sin = self.sin_cached[token_positions]
        rot_mat = torch.stack(
            (
                torch.stack((cos, -sin), dim = -1),
                torch.stack((sin, cos), dim = -1),
            ),
            dim = -2,
        )
        
        # rotate "i j, j -> i"
        x_rot = einsum(rot_mat, x_pair, "... d_pair i j, ... d_pair j -> ... d_pair i")
        out = rearrange(x_rot, "... seq_len d_pair two -> ... seq_len (d_pair two)", two = 2)
        
        return  out.to(in_type)
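
A quick shape and dtype sanity check for the module above (the hyperparameters are just illustrative):

rope = RoPE(theta=10000.0, d_k=64, max_seq_len=128)
x = torch.randn(2, 16, 64, dtype=torch.bfloat16)
positions = torch.arange(16).expand(2, -1)
out = rope(x, positions)
print(out.shape, out.dtype)  # torch.Size([2, 16, 64]) torch.bfloat16
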
TransformerBlock

Write TransformerBlock following the PDF; pay attention to reusing the modules you have already built:

class TransformerBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        max_seq_len: int,
        theta: float,
        device=None,
        dtype=None,
    ):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)

        self.attn = MultiheadSelfAttentionWithRoPE(
            d_model, num_heads, max_seq_len, theta, device, dtype
        )

        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        self.ffn = SwiGLUFFN(d_model, d_ff, device, dtype)

    def forward(
        self,
        x: Float[Tensor, "batch seq_len d_model"],
        token_positions: Int[Tensor, "batch seq_len"] | None = None,
    ) -> Float[Tensor, "batch seq_len d_model"]:
        if token_positions is None:
            token_positions = torch.arange(x.size(1), device=x.device).expand(
                x.size(0), -1
            )

        x = x + self.attn(self.ln1(x), token_positions)
        x = x + self.ffn(self.ln2(x))
        return x
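
A shape sanity check with illustrative sizes (this assumes the RMSNorm, MultiheadSelfAttentionWithRoPE, and SwiGLUFFN modules from the repo are importable and take the arguments shown above):

block = TransformerBlock(d_model=64, num_heads=4, d_ff=128, max_seq_len=256, theta=10000.0)
x = torch.randn(2, 32, 64)
print(block(x).shape)  # torch.Size([2, 32, 64])
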
TransformerLM

The model should not return probabilities after a final softmax; return the pre-softmax logits instead:

class TransformerLM(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        num_heads: int,
        d_ff: int,
        context_length: int,
        theta: float,
        num_layers: int,
        device=None,
        dtype=None,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.context_length = context_length
        self.theta = theta
        self.num_layers = num_layers
        self.device = device
        self.dtype = dtype

        param_dtype = (
            dtype
            if (
                dtype is not None
                and torch.is_floating_point(torch.tensor([], dtype=dtype))
            )
            else torch.float32
        )

        self.token_embeddings = Embedding(
            vocab_size, d_model, device=device, dtype=param_dtype
        )
        self.layers = MyLayerList(
            [
                TransformerBlock(
                    d_model=d_model,
                    num_heads=num_heads,
                    d_ff=d_ff,
                    max_seq_len=context_length,
                    theta=theta,
                    device=device,
                    dtype=param_dtype,
                )
                for _ in range(num_layers)
            ]
        )
        self.ln_final = RMSNorm(d_model, device=device, dtype=param_dtype)
        self.lm_head = Linear(d_model, vocab_size, device=device, dtype=param_dtype)

    def forward(
        self,
        input_indices: Int[Tensor, "batch seq_len"],
        token_positions: Int[Tensor, "batch seq_len"] | None = None,
    ) -> Float[Tensor, "batch seq_len vocab_size"]:
        x = self.token_embeddings(input_indices)

        if token_positions is None:
            token_positions = torch.arange(x.size(1), device=x.device).expand(
                x.size(0), -1
            )

        for layer in self.layers:
            x = layer(x, token_positions)
        x = self.ln_final(x)
        logits = self.lm_head(x)

        return logits
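
The returned logits feed straight into a cross-entropy loss during training. A minimal sketch with dummy data (the targets here are placeholders, not real next-token labels, and the Embedding, Linear, RMSNorm, and MyLayerList helpers are assumed to come from the repo):

model = TransformerLM(vocab_size=1000, d_model=64, num_heads=4, d_ff=128,
                      context_length=128, theta=10000.0, num_layers=2)
input_ids = torch.randint(0, 1000, (2, 16))
targets = torch.randint(0, 1000, (2, 16))
logits = model(input_ids)  # (2, 16, 1000)
loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), targets.flatten())
print(logits.shape, loss.item())
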
get_batch

The get_batch test is not very strict; you can also write a variant that guarantees start positions are sampled without repetition within an epoch (get_batch_without_same below).

import numpy.typing as npt
import torch
from einops import rearrange


def get_batch(
    dataset: npt.NDArray,
    batch_size: int,
    context_length: int,
    device: torch.device = torch.device("cpu"),
) -> tuple[torch.Tensor, torch.Tensor]:
    B, T = batch_size, context_length
    data_t = torch.as_tensor(dataset, dtype=torch.long, device=device)
    N = data_t.numel()
    
    # starts = torch.randint(0, N - T, (B,), device=device)
    starts = torch.randperm(N - T, device=device)[:B]  # sampling without replacement
    offsets   = rearrange(torch.arange(T + 1, device=device), 'n -> 1 n')  # [1, T+1]
    positions = rearrange(starts, 'b -> b 1') + offsets          
    tokens = data_t[positions]          # [B, T+1]
    x, y   = tokens[:, :-1], tokens[:, 1:]   # Next token prediction [B, T]
    return x, y
    
    
class EpochSampler:
    def __init__(self, num_positions: int, device: torch.device):
        self.N = num_positions            
        self.device = device
        self._shuffle()                   

    def _shuffle(self):
        self.perm = torch.randperm(self.N, device=self.device)
        self.cursor = 0                   

    def next(self, k: int) -> torch.Tensor:
        if self.cursor + k > self.N: 
            self._shuffle()
        idx = self.perm[self.cursor : self.cursor + k]
        self.cursor += k
        return idx

def get_batch_without_same(
    dataset: npt.NDArray,
    batch_size: int,
    context_length: int,
    sampler: EpochSampler,
    device: torch.device = torch.device("cpu"),
) -> tuple[torch.Tensor, torch.Tensor]:
    B, T = batch_size, context_length
    data_t = torch.as_tensor(dataset, dtype=torch.long, device=device)   # [N_total]
    N = data_t.numel()

    starts = sampler.next(B)                    # shape (B,)

    # offsets: shape (1, T+1), values 0..T
    offsets = torch.arange(T + 1, device=device).unsqueeze(0)            # (1, T+1)
    # positions: broadcast → (B, T+1)
    positions = starts.unsqueeze(1) + offsets

    tokens = data_t[positions]                  # (B, T+1)
    x, y = tokens[:, :-1], tokens[:, 1:]        # (B, T)

    return x, y
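
A usage sketch with a toy token array (sizes are illustrative):

import numpy as np

data = np.arange(1000, dtype=np.int64)  # stand-in for a tokenized corpus
sampler = EpochSampler(num_positions=len(data) - 8, device=torch.device("cpu"))
x, y = get_batch_without_same(data, batch_size=4, context_length=8, sampler=sampler)
print(x.shape, y.shape)  # torch.Size([4, 8]) torch.Size([4, 8])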

My repo also includes a few debug scripts (the debug_*.py files under cs336_basics) for debugging the tokenizer and bpe_train.

Finally, a screenshot with all tests passing:

[Screenshot: all tests passing]
