Encoding.word_ids()

该文展示了如何利用 transformers 库的 AutoTokenizer 对文本进行编码,具体操作涉及加载预训练模型 google/bigbird-roberta-base、保存 tokenizer,并对输入字符串 "i loving you Haarde" 进行编码和解码。过程中提到了子词分词方法(BigBird 的分词器基于 SentencePiece)。

一、Encoding资料

Encoding

二、代码 

from transformers import AutoTokenizer

# Local directory in which the downloaded tokenizer files are stored.
DOWNLOADED_MODEL_PATH = 'model'

# Fetch the BigBird tokenizer from the Hub and keep a local copy, so later runs
# can load it with AutoTokenizer.from_pretrained(DOWNLOADED_MODEL_PATH).
tokenizer = AutoTokenizer.from_pretrained('google/bigbird-roberta-base')
tokenizer.save_pretrained(DOWNLOADED_MODEL_PATH)

encoding = tokenizer("i loving you Haarde")

# Token ids (special tokens included) and the text they decode back to.
print(encoding['input_ids'])
print(tokenizer.decode(encoding['input_ids']))

# word_ids() gives, for every token, the index of the input word it belongs to:
# special tokens map to None and the pieces of one split word share an index.
print(encoding.word_ids())
print(tokenizer.convert_ids_to_tokens(encoding['input_ids']))

三、输出

[65, 1413, 14543, 446, 9499, 45194, 66]

[CLS] i loving you Haarde[SEP]

[None, 0, 1, 2, 3, 3, None]
['[CLS]', '▁i', '▁loving', '▁you', '▁Ha', 'arde', '[SEP]']

 这一过程使用了 SentencePiece 子词分词:输出中的 "▁" 前缀是 SentencePiece 的词首标记(BigBird 的分词器基于 SentencePiece,而不是 WordPiece)。

参考文献:保姆级教程,用PyTorch和BERT进行命名实体识别

# Positional encoding
import math

import torch
from torch import nn


class TransformerEmbedding(nn.Module):
    """Word embedding plus fixed sinusoidal position embedding, then dropout.

    ``config`` is a mapping that must provide ``vocab_size``, ``d_model``
    (the embedding width), ``pad_idx``, ``dropout`` and ``max_length``.
    """

    def __init__(self, config):
        super().__init__()
        # hyper params
        self.vocab_size = config["vocab_size"]
        self.hidden_size = config["d_model"]  # embedding dimension
        self.pad_idx = config["pad_idx"]
        dropout_rate = config["dropout"]
        self.max_length = config["max_length"]

        # layers; padding_idx keeps the embedding of the pad token at zero
        self.word_embedding = nn.Embedding(
            self.vocab_size, self.hidden_size, padding_idx=self.pad_idx
        )
        # The position table is computed once and frozen (never trained).
        self.pos_embedding = nn.Embedding.from_pretrained(
            self.get_positional_encoding(self.max_length, self.hidden_size),
            freeze=True,
        )
        self.dropout = nn.Dropout(dropout_rate)

    def get_word_embedding_weights(self):
        """Return the word-embedding weight matrix."""
        return self.word_embedding.weight

    @classmethod
    def get_positional_encoding(cls, max_length, hidden_size):
        """Build the sinusoidal position table.

        Args:
            max_length: number of positions (rows of the table).
            hidden_size: embedding width (columns of the table); may be odd.

        Returns:
            A ``(max_length, hidden_size)`` float tensor whose even columns
            hold ``sin(pos * w_k)`` and whose odd columns hold
            ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / hidden_size)``.
        """
        pe = torch.zeros(max_length, hidden_size)
        # Column vector of positions 0 .. max_length-1, shape (max_length, 1).
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        # 10000 ** (-2k / hidden_size), evaluated as exp(-2k * ln(10000) / hidden_size).
        div_term = torch.exp(
            torch.arange(0, hidden_size, 2) * -(math.log(10000.0) / hidden_size)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For an odd hidden_size there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(angles[:, : hidden_size // 2])
        return pe

    def forward(self, input_ids):
        """Embed ``input_ids`` of shape ``[batch_size, seq_len]``.

        Returns the dropout-regularised sum of word and position embeddings,
        shape ``[batch_size, seq_len, hidden_size]``.
        """
        seq_len = input_ids.shape[1]
        assert (
            seq_len <= self.max_length
        ), f"input sequence length should no more than {self.max_length} but got {seq_len}"

        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        word_embeds = self.word_embedding(input_ids)
        pos_embeds = self.pos_embedding(position_ids)
        return self.dropout(word_embeds + pos_embeds)


def plot_position_embedding(position_embedding):
    """Show the position table as a heat map (x: depth, y: position)."""
    # Imported here so that importing this module does not require matplotlib.
    import matplotlib.pyplot as plt

    plt.pcolormesh(position_embedding)
    plt.xlabel('Depth')
    plt.ylabel('Position')
    plt.colorbar()
    plt.show()


if __name__ == "__main__":
    position_embedding = TransformerEmbedding.get_positional_encoding(64, 128)
    plot_position_embedding(position_embedding)

# Reader request attached to this snippet: add comments to this transformer
# code and explain in detail what the code means.
03-26
from transformers import BertTokenizer, BertModel
import torch
from sklearn.metrics.pairwise import cosine_similarity

# Original question attached to this snippet: running the code raised
# "ValueError: Found array with dim 3. check_pairwise_arrays expected <= 2."
# -- how should it be changed?
#
# Cause: every pooled vector already has shape (1, hidden_size); wrapping it in
# a list made a 3-D input for cosine_similarity, which only accepts 2-D arrays.
# The vectors are now passed as they are.

# Load the BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertModel.from_pretrained('bert-base-chinese')

# Seed words
seed_words = ['个人信息', '隐私', '泄露', '安全']

# Cosine similarity at or above this value counts as a match.
similarity_threshold = 0.8


def embed_text(text):
    """Return the mean-pooled BERT vector of ``text``.

    The tokenizer call adds [CLS]/[SEP], so the ``1:-1`` slice removes only
    those two special tokens. The result is a ``(1, hidden_size)`` numpy
    array, or ``None`` when no real token is left to pool.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    token_states = outputs[0][:, 1:-1, :]
    if token_states.shape[1] == 0:
        return None
    return torch.mean(token_states, dim=1).numpy()


# Load the Weibo user text corpus (assumed to be stored in weibo1.txt)
with open('output/weibo1.txt', 'r', encoding='utf-8') as f:
    corpus = f.readlines()

# Embed every non-blank corpus line, keeping the line next to its vector.
corpus_entries = []
for line in corpus:
    text = line.strip()
    if not text:
        continue
    vector = embed_text(text)
    if vector is not None:
        corpus_entries.append((line, vector))

# Collect the corpus lines that are close enough to any seed word.
privacy_words = set()
for seed_word in seed_words:
    seed_word_vector = embed_text(seed_word)
    if seed_word_vector is None:
        continue
    for line, vector in corpus_entries:
        sim = cosine_similarity(seed_word_vector, vector)[0][0]
        if sim >= similarity_threshold:
            privacy_words.add(line)

print(privacy_words)
05-31
import tkinter as tk
from tkinter import ttk, scrolledtext, messagebox
import subprocess
import threading
import queue
import re
import time
import os
from datetime import datetime


class DeviceManager:
    """Thread-safe registry of adb-attached devices and their test state."""

    # Upper bound for one `adb devices` call, so a hung adb cannot block forever.
    ADB_TIMEOUT_S = 15

    def __init__(self):
        # {device_id: {'status': 'idle', 'session_id': None, 'last_test': None}}
        self.devices = {}
        self.lock = threading.Lock()

    @staticmethod
    def _parse_adb_devices(output):
        """Return the serials listed in state ``device`` by `adb devices`.

        The first line (the header) is skipped. The state column is compared
        exactly, so a serial that merely contains the word "device" is not
        mistaken for an online device.
        """
        serials = []
        for line in output.splitlines()[1:]:
            fields = line.split()
            if len(fields) >= 2 and fields[1] == "device":
                serials.append(fields[0])
        return serials

    def refresh_devices(self):
        """Re-read the attached devices from adb.

        State already recorded for a device that is still attached is kept.

        Returns:
            ``(True, "")`` on success, ``(False, error_message)`` when adb
            could not be run.
        """
        try:
            result = subprocess.run(
                ["adb", "devices"],
                capture_output=True,
                text=True,
                encoding="utf-8",
                errors="replace",
                timeout=self.ADB_TIMEOUT_S,
            )
        except (OSError, subprocess.SubprocessError) as e:
            return False, str(e)

        serials = self._parse_adb_devices(result.stdout)
        with self.lock:
            self.devices = {
                serial: self.devices.get(
                    serial,
                    {'status': 'idle', 'session_id': None, 'last_test': None},
                )
                for serial in serials
            }
        return True, ""

    def update_device_status(self, device_id, status, session_id=None):
        """Set the status of a known device.

        ``session_id`` is stored only when it is truthy; a ``completed``
        status also stamps ``last_test`` with the current local time.
        Unknown device ids are ignored.
        """
        with self.lock:
            info = self.devices.get(device_id)
            if info is None:
                return
            info['status'] = status
            if session_id:
                info['session_id'] = session_id
            if status == 'completed':
                info['last_test'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def get_selected_devices(self, selected_ids):
        """Return ``{device_id: info}`` for the known devices in ``selected_ids``."""
        with self.lock:
            return {did: info for did, info in self.devices.items() if did in selected_ids}


class GMSTestAssistant(tk.Tk):
    """Main window: device list, test controls and a colour-coded log.

    Tests run in worker threads. Workers never touch Tk widgets; they hand UI
    work to the main thread through a queue that the main thread polls.
    """

    # Text shown in the status column for each internal status.
    STATUS_LABELS = {'idle': "空闲", 'testing': "测试中", 'completed': "已完成"}
    # Row background for each internal status.
    STATUS_COLORS = {'idle': '#d9ead3', 'testing': '#fce5cd', 'completed': '#c9daf8'}
    # Log text colour for each log level.
    LOG_COLORS = {
        "error": "#f48771",
        "success": "#6a9955",
        "warning": "#dcdcaa",
        "info": "#d4d4d4",
    }
    # How often the main thread drains the worker-to-UI queue.
    UI_POLL_MS = 100

    def __init__(self):
        super().__init__()
        self.title("GMS测试助手 v3.0 - 多设备管理")
        self.geometry("1100x800")
        self.device_manager = DeviceManager()

        # Test state
        self.test_running = False
        self.active_threads = {}

        # Worker threads post (callback, args) here; only the UI thread runs them.
        self._ui_thread = threading.current_thread()
        self._ui_queue = queue.Queue()

        self.create_widgets()
        self.refresh_devices()

        self.after(self.UI_POLL_MS, self._drain_ui_queue)
        # Periodic device refresh
        self.after(5000, self.periodic_refresh)

    def create_widgets(self):
        """Build the device panel, test controls, log area and status bar."""
        # Device management panel
        device_frame = ttk.LabelFrame(self, text="设备管理")
        device_frame.pack(fill="x", padx=15, pady=10)

        # "tree" must be shown as well: the device ID lives in the #0 column.
        self.device_tree = ttk.Treeview(
            device_frame,
            columns=("status", "session", "last_test"),
            show="tree headings",
            height=8,
        )
        self.device_tree.heading("#0", text="设备ID")
        self.device_tree.heading("status", text="状态")
        self.device_tree.heading("session", text="Session ID")
        self.device_tree.heading("last_test", text="最后测试时间")
        self.device_tree.column("#0", width=200)
        self.device_tree.column("status", width=100)
        self.device_tree.column("session", width=150)
        self.device_tree.column("last_test", width=150)
        for status, color in self.STATUS_COLORS.items():
            self.device_tree.tag_configure(status, background=color)

        # Scrollbar
        scrollbar = ttk.Scrollbar(device_frame, orient="vertical", command=self.device_tree.yview)
        self.device_tree.configure(yscrollcommand=scrollbar.set)

        # Layout
        self.device_tree.pack(side="left", fill="both", expand=True, padx=(0, 5))
        scrollbar.pack(side="right", fill="y")

        # Device action buttons
        btn_frame = ttk.Frame(device_frame)
        btn_frame.pack(side="right", fill="y", padx=5)
        ttk.Button(btn_frame, text="刷新设备", command=self.refresh_devices).pack(pady=5)
        ttk.Button(btn_frame, text="全选", command=self.select_all).pack(pady=5)
        ttk.Button(btn_frame, text="取消全选", command=self.deselect_all).pack(pady=5)

        # Test control panel
        control_frame = ttk.LabelFrame(self, text="测试控制")
        control_frame.pack(fill="x", padx=15, pady=10)

        # Test type selection
        test_type_frame = ttk.Frame(control_frame)
        test_type_frame.pack(fill="x", pady=5)
        self.test_type = tk.StringVar(value="full")
        ttk.Radiobutton(test_type_frame, text="完整测试", variable=self.test_type, value="full").pack(side="left", padx=10)
        ttk.Radiobutton(test_type_frame, text="单模块测试", variable=self.test_type, value="module").pack(side="left", padx=10)
        ttk.Radiobutton(test_type_frame, text="重测失败", variable=self.test_type, value="retry").pack(side="left", padx=10)

        # Test parameters
        param_frame = ttk.Frame(control_frame)
        param_frame.pack(fill="x", pady=5)
        ttk.Label(param_frame, text="模块:").pack(side="left", padx=(5, 0))
        self.module_var = tk.StringVar()
        ttk.Entry(param_frame, textvariable=self.module_var, width=25).pack(side="left")
        ttk.Label(param_frame, text="测试项:").pack(side="left", padx=(10, 0))
        self.test_case_var = tk.StringVar()
        ttk.Entry(param_frame, textvariable=self.test_case_var, width=25).pack(side="left")
        ttk.Label(param_frame, text="重试次数:").pack(side="left", padx=(20, 5))
        self.retry_count_var = tk.IntVar(value=1)
        ttk.Spinbox(param_frame, from_=1, to=10, width=5, textvariable=self.retry_count_var).pack(side="left")

        # Execute button
        ttk.Button(control_frame, text="执行测试", command=self.execute_tests, width=15).pack(pady=10)

        # Log area
        log_frame = ttk.LabelFrame(self, text="测试日志")
        log_frame.pack(fill="both", expand=True, padx=15, pady=10)
        self.log_text = scrolledtext.ScrolledText(
            log_frame,
            wrap="word",
            font=("Consolas", 10),
            bg="#1e1e1e",
            fg="#d4d4d4",
        )
        self.log_text.pack(fill="both", expand=True, padx=5, pady=5)
        for level, color in self.LOG_COLORS.items():
            self.log_text.tag_config(level, foreground=color)

        # Status bar
        self.status_var = tk.StringVar(value="就绪")
        status_bar = ttk.Label(self, textvariable=self.status_var, relief="sunken")
        status_bar.pack(side="bottom", fill="x")

    def _post_to_ui(self, callback, *args):
        """Run ``callback(*args)`` on the UI thread (at once if already on it)."""
        if threading.current_thread() is self._ui_thread:
            callback(*args)
        else:
            self._ui_queue.put((callback, args))

    def _drain_ui_queue(self):
        """Run every callback queued by worker threads, then re-arm the poll."""
        try:
            while True:
                callback, args = self._ui_queue.get_nowait()
                callback(*args)
        except queue.Empty:
            pass
        finally:
            # Re-arm even if a callback raised, so later UI work is not lost.
            self.after(self.UI_POLL_MS, self._drain_ui_queue)

    def refresh_devices(self):
        """Reload the device list from adb and rebuild the tree view.

        Rows that were selected before the rebuild are selected again.
        """
        self.log("刷新设备列表中...", "info")
        success, message = self.device_manager.refresh_devices()
        if not success:
            self.log(f"刷新设备失败: {message}", "error")
            return

        tree = self.device_tree
        previously_selected = {tree.item(item, "text") for item in tree.selection()}

        # Clear the current rows
        for item in tree.get_children():
            tree.delete(item)

        # Insert the current devices
        devices = dict(self.device_manager.devices)
        for device_id, info in devices.items():
            # Any status without its own label is shown as "testing".
            status = self.STATUS_LABELS.get(info['status'], "测试中")
            session = info['session_id'] or "无"
            last_test = info['last_test'] or "从未测试"
            item = tree.insert(
                "",
                "end",
                text=device_id,
                values=(status, session, last_test),
                tags=(info['status'],),
            )
            if device_id in previously_selected:
                tree.selection_add(item)

        self.log(f"找到 {len(devices)} 台设备", "success")

    def periodic_refresh(self):
        """Refresh the device list every 5 seconds while no test is running."""
        if not self.test_running:
            self.refresh_devices()
        self.after(5000, self.periodic_refresh)

    def select_all(self):
        """Select every device row."""
        for item in self.device_tree.get_children():
            self.device_tree.selection_add(item)

    def deselect_all(self):
        """Clear the selection."""
        self.device_tree.selection_set([])

    def execute_tests(self):
        """Validate the form, mark the selected devices busy and start the tests."""
        selected_items = self.device_tree.selection()
        if not selected_items:
            self.log("错误: 请至少选择一个设备", "error")
            return

        # Validate before any state changes, so a rejected request leaves no
        # device stuck in the "testing" state.
        test_type = self.test_type.get()
        module = self.module_var.get().strip()
        if test_type == "module" and not module:
            self.log("错误: 请输入测试模块名称", "error")
            return

        # IDs of the selected devices
        device_ids = [self.device_tree.item(item, "text") for item in selected_items]

        # Mark them as under test and show it
        for device_id in device_ids:
            self.device_manager.update_device_status(device_id, "testing")
        self.refresh_devices()

        self.test_running = True
        self.status_var.set("测试执行中...")

        if test_type == "full":
            self.run_full_test(device_ids)
        elif test_type == "module":
            test_case = self.test_case_var.get().strip()
            self.run_module_test(device_ids, module, test_case)
        elif test_type == "retry":
            retry_count = self.retry_count_var.get()
            self.retry_failed(device_ids, retry_count)

        # Nothing was started (e.g. no device had a session to retry).
        if not self.active_threads:
            self.test_running = False
            self.status_var.set("就绪")
            self.refresh_devices()

    def _start_worker(self, device_id, target, args):
        """Start a daemon worker thread for ``device_id`` and register it."""
        thread = threading.Thread(target=target, args=args, daemon=True)
        self.active_threads[device_id] = thread
        thread.start()

    def run_full_test(self, device_ids):
        """Start a full test run on every given device."""
        for device_id in device_ids:
            self._start_worker(
                device_id,
                self._run_device_test,
                (device_id, "run cts --shard-count 3"),
            )
            self.log(f"设备 {device_id} 开始完整测试", "info")

    def run_module_test(self, device_ids, module, test_case=None):
        """Start a single-module test (optionally one test case) on every device."""
        command = f"run cts -m {module}"
        if test_case:
            command += f" -t {test_case}"
        for device_id in device_ids:
            self._start_worker(device_id, self._run_device_test, (device_id, command))
            self.log(f"设备 {device_id} 开始测试模块: {module}", "info")

    def retry_failed(self, device_ids, retry_count):
        """Re-run the last session of every device ``retry_count`` times.

        Devices with no recorded session are skipped and set back to idle.
        The retries of one device run back to back inside one worker thread,
        so the UI thread is never blocked waiting for them.
        """
        for device_id in device_ids:
            # Session ID recorded by the previous run of this device
            info = self.device_manager.get_selected_devices([device_id]).get(device_id, {})
            session_id = info.get('session_id')
            if not session_id:
                self.log(f"设备 {device_id} 无可用Session ID,跳过重试", "warning")
                self.device_manager.update_device_status(device_id, "idle")
                continue
            self._start_worker(
                device_id,
                self._run_retry_sequence,
                (device_id, session_id, retry_count),
            )

    def _run_retry_sequence(self, device_id, session_id, retry_count):
        """Worker body: run the retry command ``retry_count`` times in a row."""
        command = f"run retry --retry {session_id}"

        def attempts():
            for attempt in range(1, retry_count + 1):
                self.log(
                    f"设备 {device_id} 开始第 {attempt}/{retry_count} 次重试 (Session: {session_id})",
                    "info",
                )
                self._execute_command(device_id, command)

        self._run_guarded(device_id, attempts)

    def _run_device_test(self, device_id, command):
        """Worker body: run one test command on one device."""
        self._run_guarded(device_id, lambda: self._execute_command(device_id, command))

    def _run_guarded(self, device_id, work):
        """Run ``work()`` in a worker, report errors and always signal completion."""
        try:
            work()
        except Exception as e:  # thread boundary: report instead of dying silently
            self.log(f"设备 {device_id} 测试错误: {str(e)}", "error")
            self.device_manager.update_device_status(device_id, "idle")
        finally:
            self._post_to_ui(self._on_device_finished, device_id)

    def _execute_command(self, device_id, command):
        """Run one test command on a device (the run is simulated for now)."""
        self.log(f"设备 {device_id}: 开始执行命令: {command}", "info")

        # In a real deployment, replace the loop below with the actual command,
        # e.g. subprocess.run(f"adb -s {device_id} shell {command}", ...)
        for i in range(1, 11):
            if not self.test_running:
                break
            time.sleep(1)
            progress = i * 10
            self.log(f"设备 {device_id}: 测试进度 {progress}%", "info")

        # Simulated capture of the session ID
        session_id = f"{device_id[:4]}-{int(time.time())}"
        self.device_manager.update_device_status(device_id, "completed", session_id)
        self.log(f"设备 {device_id} 测试完成! Session ID: {session_id}", "success")

    def _on_device_finished(self, device_id):
        """UI-thread bookkeeping after the worker of ``device_id`` has ended."""
        self.active_threads.pop(device_id, None)
        # When the last worker is gone the whole run is finished.
        if not self.active_threads:
            self.test_running = False
            self.status_var.set("测试完成")
        self.refresh_devices()

    def log(self, message, level="info"):
        """Append a colour-coded line to the log box.

        May be called from any thread; calls from worker threads are queued
        and carried out by the UI thread. Unknown levels are shown as "info".
        """
        if threading.current_thread() is not self._ui_thread:
            self._ui_queue.put((self.log, (message, level)))
            return

        tag = level if level in self.LOG_COLORS else "info"
        self.log_text.configure(state="normal")
        self.log_text.insert("end", message + "\n", tag)
        self.log_text.see("end")
        self.log_text.configure(state="disabled")
        self.update_idletasks()


if __name__ == "__main__":
    app = GMSTestAssistant()
    app.mainloop()
07-30
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值