tf.scatter_update and tf.batch_scatter_update

This article explains and demonstrates how to update tensors with TensorFlow's tf.scatter_update and tf.batch_scatter_update functions. Examples show how indices selects which entries of ref get replaced, and how the shape of updates relates to the shape of ref.


tf.scatter_update

Function signature:

tf.scatter_update(
    ref,
    indices,
    updates,
    use_locking=True,
    name=None
)

Note that updates.shape = [*indices.shape, *ref.shape[1:]]; the shape of updates is not necessarily equal to the shape of ref.
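As a quick illustration of this shape rule, here is a minimal sketch (assuming the TF 1.x API; the names ref, indices, updates are made up for this example):

import tensorflow as tf  # TF 1.x

# ref.shape = [5, 3] and indices.shape = [2, 2], so updates must have
# shape [2, 2, 3] = [*indices.shape, *ref.shape[1:]]
ref = tf.Variable(tf.zeros([5, 3]))
indices = tf.constant([[0, 1], [2, 3]])
updates = tf.ones([2, 2, 3])
out = tf.scatter_update(ref, indices, updates)  # shapes are compatible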

Test:

import tensorflow as tf  # TF 1.x

a = tf.Variable([[1, 2, 3, 4], [5, 6, 7, 8]])
indices = [[0, 1], [1, 0]]
updates = [[[1, 1, 1, 1], [2, 3, 4, 5]], [[2, 2, 2, 2], [3, 3, 3, 3]]]
b = tf.scatter_update(a, indices, updates)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())  # initialize a before running the update
print(sess.run([b]))

Result:

[array([[3, 3, 3, 3],
        [2, 2, 2, 2]])]

Explanation:

The key is understanding what indices means.

The values in indices specify which entries of ref get replaced. In the test above, the values 0 and 1 appearing in indices mean that a[0] and a[1] will be replaced.

For each position index within indices, the entry ref[indices[index]] is replaced by updates[index]. In the test above, indices[0][1] is 1, so a[1] is replaced by updates[0][1].

Also, a value that appears more than once in indices is written more than once, and which write survives is not deterministic.
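A minimal sketch of the duplicate-index case, reusing the sess from the test above (c, dup_indices, and dup_updates are names made up for this example):

c = tf.Variable([[1, 1], [2, 2]])
dup_indices = [0, 0]              # index 0 appears twice
dup_updates = [[7, 7], [9, 9]]    # two competing updates for row 0
c_out = tf.scatter_update(c, dup_indices, dup_updates)

sess.run(tf.global_variables_initializer())
print(sess.run(c_out))  # row 0 ends up as [7, 7] or [9, 9]; row 1 stays [2, 2]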


tf.batch_scatter_update

Function signature:

Same as tf.scatter_update.

Test:

d = tf.Variable([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
indices = [[1, 1], [1, 0]]
updates = [[[1, 1], [2, 2]], [[3, 3], [4, 4]]]
e = tf.batch_scatter_update(d, indices, updates)
sess.run(tf.global_variables_initializer())
print(sess.run([e]))

Result:

[array([[[0, 1],
         [2, 2]],
 
        [[4, 4],
         [3, 3]]])]

Explanation:

tf.batch_scatter_update is similar to tf.scatter_update; the difference is in how the target entry of ref is chosen. In tf.scatter_update the target is determined by the value in indices alone, while in tf.batch_scatter_update it is determined jointly by the value in indices and the position index[:-1] (the leading batch dimensions) of that value.

In the test above, indices[0][1] is 1, so d[0][1] is replaced by updates[0][1]; the leading 0 in d[0][1] comes from the position of the index within indices, and the trailing 1 is its value.
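To make this indexing rule concrete, here is a pure-NumPy sketch for the two-dimensional indices case used in this test (batch_scatter_update_ref is a helper written for this example, not a TensorFlow API):

import numpy as np

# Reference semantics: for every batch position b and inner position j,
#   out[b, indices[b, j]] = updates[b, j]
def batch_scatter_update_ref(ref, indices, updates):
    out = ref.copy()
    for b in range(indices.shape[0]):
        for j in range(indices.shape[1]):
            out[b, indices[b, j]] = updates[b, j]
    return out

d = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
indices = np.array([[1, 1], [1, 0]])
updates = np.array([[[1, 1], [2, 2]], [[3, 3], [4, 4]]])
print(batch_scatter_update_ref(d, indices, updates))
# [[[0 1] [2 2]] [[4 4] [3 3]]] -- matches the TF result above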
