[Triton Tutorial] triton.language.full

Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for efficiently writing custom DNN compute kernels that can run at maximal throughput on modern GPU hardware.

More Triton documentation in Chinese is available at → https://triton.hyper.ai/

triton.language.full(shape, value, dtype)

Returns a tensor of the given shape and dtype, filled with a scalar value (see the example below the parameter list).

**Parameters:**

  • shape (tuple of ints) – The shape of the new array, e.g. (8, 16) or (8,).
  • value (scalar) – The scalar value to fill the array with.
  • dtype (tl.dtype) – The data type of the new array, e.g. tl.float16.
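
A minimal usage sketch (assuming a CUDA device and a recent Triton release): the kernel below uses `tl.full` to materialize a constant block in registers and then store it to global memory. Note that inside a `@triton.jit` kernel the shape must be built from compile-time constants (`tl.constexpr`). The kernel name `fill_kernel` and the launch parameters are illustrative, not part of the Triton API.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def fill_kernel(out_ptr, n_elements, VALUE: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the output.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # tl.full builds a (BLOCK_SIZE,) block filled with VALUE, in float16.
    block = tl.full((BLOCK_SIZE,), VALUE, dtype=tl.float16)
    tl.store(out_ptr + offsets, block, mask=mask)


out = torch.empty(1000, device="cuda", dtype=torch.float16)
grid = (triton.cdiv(out.numel(), 128),)
fill_kernel[grid](out, out.numel(), VALUE=3.0, BLOCK_SIZE=128)
# every element of `out` is now 3.0
```

In practice `tl.full` is most often used to initialize block-shaped accumulators or constants (for example, a running maximum initialized to negative infinity) before a loop over tiles.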