import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
import argparse
import json
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from transformers import PreTrainedTokenizerFast
from tokenizers.pre_tokenizers import Split
from tokenizers import normalizers
from tokenizers.normalizers import NFKC, StripAccents
import thulac
from configs.module_factory import ModuleFactory
from common.utils.validation_util import positive_int
# Word-segmentation function (THULAC, segmentation only)
def split_fn(text):
    thu = thulac.thulac(seg_only=True)  # segmentation only, no POS tagging
    # text=True makes cut() return a space-joined string; split it into a word list
    return thu.cut(text, text=True).split()

# Read a JSONL file and yield the "text" field of each record
def read_texts_from_jsonl(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            text = data.get("text", "").strip()
            if text:  # skip empty texts
                yield text

def train_tokenizer():
    data_path = f"{root_dir}/common/dataset/pretrain_hq.jsonl"
    texts = read_texts_from_jsonl(data_path)

    # 1. Define the special tokens
    special_tokens = [
        "<|endoftext|>",
        "<|im_start|>",
        "<|im_end|>",
        "[UNK]",
    ]

    # 2. Initialize a WordPiece tokenizer
    tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))

    # 3. Add normalization for Chinese text
    tokenizer.normalizer = normalizers.Sequence(
        [
            NFKC(),          # Unicode compatibility normalization
            StripAccents(),  # strip accent/diacritic marks
        ]
    )

    # 4. Set the pre-tokenizer (Chinese needs custom segmentation)
    tokenizer.pre_tokenizer = Split(split_fn, behavior="removed")

    # 5. Configure the WordPiece trainer
    trainer = WordPieceTrainer(
        vocab_size=args.vocab_size,          # vocabulary size
        special_tokens=special_tokens,
        min_frequency=args.min_frequency,    # minimum token frequency
        show_progress=True,
        continuing_subword_prefix="",        # continuation-subword prefix (key parameter)
        limit_alphabet=args.limit_alphabet,  # alphabet size limit (needs to be larger for Chinese)
    )

    # 6. Train
    tokenizer.train_from_iterator(texts, trainer=trainer)

    # 7. Check the indices of the special tokens
    assert tokenizer.token_to_id("<|endoftext|>") == 0
    assert tokenizer.token_to_id("<|im_start|>") == 1
    assert tokenizer.token_to_id("<|im_end|>") == 2
    assert tokenizer.token_to_id("[UNK]") == 3

    # 8. Save and convert to the Transformers format
    chat_template = "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"
    tokenizer_dir = f"{root_dir}/common/tf_tokenizer/"
    os.makedirs(tokenizer_dir, exist_ok=True)
    # no leading slash here, otherwise os.path.join() discards tokenizer_dir
    tokenizer_file = os.path.join(tokenizer_dir, "tokenizer.json")
    tokenizer.save(tokenizer_file)
    transformers_tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_file,
        bos_token="<|im_start|>",
        eos_token="<|im_end|>",
        pad_token="<|endoftext|>",
        unk_token="[UNK]",
        tokenize_chinese_chars=True,  # Chinese-handling option
        chat_template=chat_template,
    )
    transformers_tokenizer.save_pretrained(tokenizer_dir)
    print("Tokenizer training completed and saved.")

def eval_tokenizer():
    from transformers import AutoTokenizer

    # Load the trained tokenizer
    tokenizer = AutoTokenizer.from_pretrained(f"{root_dir}/common/tf_tokenizer/")
    # Actual vocabulary size (including special tokens)
    actual_vocab_size = len(tokenizer)
    print("actual tokenizer vocab size:", actual_vocab_size)

    messages = [
        {"role": "system", "content": "你是一个优秀的聊天机器人,总是给我正确的回应!"},
        {"role": "user", "content": "你来自哪里?"},
        {"role": "assistant", "content": "我来自地球"},
    ]
    new_prompt = tokenizer.apply_chat_template(messages, tokenize=False)
    print(f"original text: {new_prompt}")

    model_inputs = tokenizer(new_prompt)
    print("encoded length:", len(model_inputs["input_ids"]))

    input_ids = model_inputs["input_ids"]
    response = tokenizer.decode(input_ids, skip_special_tokens=False)
    print(f"decoded text: {response}")
    print("decoded text matches the original:", response == new_prompt)

def get_parser():
    global root_dir
    module = ModuleFactory()
    root_dir = module.get_root_dir()
    parser = argparse.ArgumentParser(description="Transformers Tokenizer train")
    parser.add_argument(
        "--min_frequency",
        type=positive_int,
        default=2,
        help="minimum token frequency (filters rare tokens)",
    )
    parser.add_argument(
        "--limit_alphabet",
        type=positive_int,
        default=1000,
        help="alphabet size limit",
    )
    parser.add_argument(
        "--vocab_size",
        type=positive_int,
        default=6400,
        help="vocabulary size: 6400, 16384, 32768, 65536",
    )
    parser.add_argument(
        "--jieba_parallel",
        type=positive_int,
        default=32,
        help="number of parallel jieba threads",
    )
    return parser

def main(remaining_args=None):
    global root_dir
    module = ModuleFactory()
    root_dir = module.get_root_dir()
    global args
    parser = get_parser()
    args = parser.parse_args(remaining_args)
    args.description = parser.description
    train_tokenizer()
    eval_tokenizer()


if __name__ == "__main__":
    main()
The code above raises the following TypeError:

TypeError: argument 'pattern': failed to extract enum PyPattern ('str | tokenizers.Regex')
- variant Str (str): TypeError: failed to extract field PyPattern::Str.0, caused by TypeError: 'function' object cannot be converted to 'PyString'
- variant Regex (tokenizers.Regex): TypeError: failed to extract field PyPattern::Regex.0, caused by TypeError: 'function' object cannot be converted to 'Regex'
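The traceback points at step 4: `tokenizers.pre_tokenizers.Split` takes a pattern, i.e. a `str` or a `tokenizers.Regex`, not a Python callable, so passing `split_fn` fails with "'function' object cannot be converted to 'PyString'". One way to plug a Python segmenter such as THULAC into the pipeline is `pre_tokenizers.PreTokenizer.custom`, which wraps any object exposing a `pre_tokenize(self, pretok)` method. The sketch below is a minimal illustration of that approach; the `ThulacPreTokenizer` class name and its offset-recovery logic are hypothetical additions, not part of the original script, and a custom component cannot be serialized into tokenizer.json, so it would need to be swapped out (or the corpus pre-segmented with spaces) before `tokenizer.save(...)`.

from typing import List
import thulac
from tokenizers import NormalizedString, PreTokenizedString
from tokenizers.pre_tokenizers import PreTokenizer, Whitespace

# Hypothetical wrapper: replaces the Split(split_fn, ...) line in step 4.
class ThulacPreTokenizer:
    def __init__(self):
        # Build the THULAC model once; re-instantiating it per call (as in split_fn) is very slow.
        self.thu = thulac.thulac(seg_only=True)

    def thulac_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        text = str(normalized_string)
        words = self.thu.cut(text, text=True).split()  # text=True returns a space-joined string
        splits, offset = [], 0
        for word in words:
            start = text.find(word, offset)
            if start == -1:  # should not happen for pure segmentation, but stay safe
                continue
            stop = start + len(word)
            splits.append(normalized_string[start:stop])
            offset = stop
        return splits

    def pre_tokenize(self, pretok: PreTokenizedString):
        # Split the PreTokenizedString in place using the THULAC word boundaries
        pretok.split(self.thulac_split)

tokenizer.pre_tokenizer = PreTokenizer.custom(ThulacPreTokenizer())

# Caveat: a custom pre-tokenizer cannot be serialized into tokenizer.json.
# Before tokenizer.save(...), swap it for a serializable one, e.g.:
#     tokenizer.pre_tokenizer = Whitespace()
# Alternatively, skip the custom component entirely: pre-segment each line with THULAC,
# join the words with spaces in read_texts_from_jsonl, and train with Whitespace().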