[Triton Tutorial] triton_language.zeros_like

Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for efficiently writing custom DNN compute kernels that can run at maximum throughput on modern GPU hardware. More Triton documentation in Chinese is available at → https://triton.hyper.ai/

triton.language.zeros_like(input)

Returns an all-zero tensor with the same shape and dtype as the given tensor.

Parameters

  • input (Tensor) - the input tensor.
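
The snippet below is a minimal sketch (not part of the original page) of how tl.zeros_like might be used inside a kernel: the all-zero block, which matches the shape and dtype of a loaded block, serves as the floor of an element-wise maximum, producing a ReLU-style result. The kernel name, BLOCK_SIZE value, and the ReLU comparison are illustrative assumptions, and the host-side launch assumes a CUDA-capable GPU.

import torch
import triton
import triton.language as tl

@triton.jit
def zeros_like_relu_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # tl.zeros_like: an all-zero block with the same shape and dtype as x
    zeros = tl.zeros_like(x)
    # use the zero block as the floor of an element-wise maximum (ReLU)
    tl.store(out_ptr + offsets, tl.maximum(x, zeros), mask=mask)

# hypothetical host-side launch (assumes a CUDA GPU is available)
x = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
zeros_like_relu_kernel[grid](x, out, x.numel(), BLOCK_SIZE=1024)
assert torch.allclose(out, torch.relu(x))

Because zeros_like copies both shape and dtype from its argument, the comparison stays well-typed even if the kernel is later instantiated with float16 or bfloat16 inputs.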