loss_mask的设置
loss_mask[input_tokens == eos_token_id] = 0
假设有一个原始序列: a b c d e f g
input: a b c d e f
output: b c d e f g
假设c是eos token,则上述操作是把output中的d mask掉,即不预测eos之后的token。在随机拼接的情况下,这种token一般没啥含义。如果保留了应该也没啥影响。
SFT时候loss mask设置
SFT时候如果要对input进行mask,则
错误:loss_mask[:seq_len] = sample['loss_mask'][:-1]
正确:loss_mask[:seq_len] = sample['loss_mask'][1:]
a b c d e f g
0 0 1 1 1 1 1 这里表示我们需要预测 c d e f g
input: a b c d e f
0 0 1 1 1 1 表示从第2个位置(下标0开始)开始才预测,即从d开始预测,所以错误
output: b c d e f g
0 1 1 1 1 1