php源码之array_key_exists

本文详细解析了PHP函数array_key_exists的源码实现过程,包括参数解析、字符串和整数类型的键值检测流程,以及 zend_hash_index_exists 和 zend_hash_exists_ind 函数的工作原理。
部署运行你感兴趣的模型镜像

先定一个flag,以后争取每天理解一个php函数源码,熟悉php内核。

函数array_key_exists等数组相关的函数都定义在源码ext/standard/array.c文件中。

array_key_exists函数的作用是检测给定数组array中是否有键为key的值。源码如下:

PHP_FUNCTION(array_key_exists)
{
	zval *key;					/* key to check for */
	HashTable *array;			/* array to check in */

	ZEND_PARSE_PARAMETERS_START(2, 2)
		Z_PARAM_ZVAL(key)
		Z_PARAM_ARRAY_OR_OBJECT_HT(array)
	ZEND_PARSE_PARAMETERS_END();

	switch (Z_TYPE_P(key)) {
		case IS_STRING:
			if (zend_symtable_exists_ind(array, Z_STR_P(key))) {
				RETURN_TRUE;
			}
			RETURN_FALSE;
		case IS_LONG:
			if (zend_hash_index_exists(array, Z_LVAL_P(key))) {
				RETURN_TRUE;
			}
			RETURN_FALSE;
		case IS_NULL:
			if (zend_hash_exists_ind(array, ZSTR_EMPTY_ALLOC())) {
				RETURN_TRUE;
			}
			RETURN_FALSE;

		default:
			php_error_docref(NULL, E_WARNING, "The first argument should be either a string or an integer");
			RETURN_FALSE;
	}
}

如上,php7中采用zend_parse_parameters_start来代替zend_parse_parameters函数获取参数值,具体实现我还没有弄明白,后面补上,反正作用和zend_parse_parameters一样,之前也有写关于它的博客。

指针变量key指向一个zval结构体,通过判断这个key是字符串类型还是整型分别处理。

如果key指向的zval变量是string类型:

调用函数zend_symtabl_exists_ind函数,传入的第二个参数Z_STR_P(key)展开就是*(key).value.str,不熟悉zval结构的可以自行google一下。函数zend_symtabl_exists_ind定义在zend_hash.h中,代码如下:

static zend_always_inline int zend_symtable_exists_ind(HashTable *ht, zend_string *key)
{
	zend_ulong idx;

	if (ZEND_HANDLE_NUMERIC(key, idx)) {
		return zend_hash_index_exists(ht, idx);
	} else {
		return zend_hash_exists_ind(ht, key);
	}
}

这里面再次先通过函数zend_handle_numeric对key做判断,看这个字符串是不是数字类型的,比如说‘1234’,‘23’这种字符串。判断代码如下:

static zend_always_inline int _zend_handle_numeric_str(const char *key, size_t length, zend_ulong *idx)
{
	const char *tmp = key;

	if (*tmp > '9') {
		return 0;
	} else if (*tmp < '0') {
		if (*tmp != '-') {
			return 0;
		}
		tmp++;
		if (*tmp > '9' || *tmp < '0') {
			return 0;
		}
	}
	return _zend_handle_numeric_str_ex(key, length, idx);
}

这里比较*tmp > '9'其实是比较的ascii码。如果key确实是字符串数字,那么返回函数zend_handle_numeric_str_ex的结果:

ZEND_API int ZEND_FASTCALL _zend_handle_numeric_str_ex(const char *key, size_t length, zend_ulong *idx)
{
	register const char *tmp = key;

	const char *end = key + length;

	if (*tmp == '-') {
		tmp++;
	}

	if ((*tmp == '0' && length > 1) /* numbers with leading zeros */
	 || (end - tmp > MAX_LENGTH_OF_LONG - 1) /* number too long */
	 || (SIZEOF_ZEND_LONG == 4 &&
	     end - tmp == MAX_LENGTH_OF_LONG - 1 &&
	     *tmp > '2')) { /* overflow */
		return 0;
	}
	*idx = (*tmp - '0');
	while (1) {
		++tmp;
		if (tmp == end) {
			if (*key == '-') {
				if (*idx-1 > ZEND_LONG_MAX) { /* overflow */
					return 0;
				}
				*idx = 0 - *idx;
			} else if (*idx > ZEND_LONG_MAX) { /* overflow */
				return 0;
			}
			return 1;
		}
		if (*tmp <= '9' && *tmp >= '0') {
			*idx = (*idx * 10) + (*tmp - '0');
		} else {
			return 0;
		}
	}
}

这个函数的作用就是将字符串数字,通过逐个的获取然后计算出它的整型值,如‘23’=》2*10+3*1(当然代码中没这么简单,我这里是打个比方说这个函数的作用),最后赋值给idx变量。

然后到了最关键的一步,检测idx这个key是否存在于数组ht中了。zend_hash_index_exists(ht,idx);

ZEND_API zend_bool ZEND_FASTCALL zend_hash_index_exists(const HashTable *ht, zend_ulong h)
{
	Bucket *p;

	IS_CONSISTENT(ht);

	if (ht->u.flags & HASH_FLAG_PACKED) {
		if (h < ht->nNumUsed) {
			if (Z_TYPE(ht->arData[h].val) != IS_UNDEF) {
				return 1;
			}
		}
		return 0;
	}

	p = zend_hash_index_find_bucket(ht, h);
	return p ? 1 : 0;
}

这里有一个宏HASH_FLAG_PACKED,为真就代表当前数组的key都是系统生成的,也就是说是按从0到1,2,3等等按序排列的,所以判读键为key的是否存在,直接检查arData数组中第idx个元素是否有定义就行了,这里不涉及什么hash查找,冲突解决等一系列问题。但如果HASH_FLAG_PACKED为假,那么肯定就需要先计算idx的hash值,找到key为idx的数据应该在arData的第几位才行。这就要通过函数zend_hash_index_find_bucket了。

static zend_always_inline Bucket *zend_hash_index_find_bucket(const HashTable *ht, zend_ulong h)
{
	uint32_t nIndex;
	uint32_t idx;
	Bucket *p, *arData;

	arData = ht->arData;
	nIndex = h | ht->nTableMask;
	idx = HT_HASH_EX(arData, nIndex);
	while (idx != HT_INVALID_IDX) {
		ZEND_ASSERT(idx < HT_IDX_TO_HASH(ht->nTableSize));
		p = HT_HASH_TO_BUCKET_EX(arData, idx);
		if (p->h == h && !p->key) {
			return p;
		}
		idx = Z_NEXT(p->val);
	}
	return NULL;
}

这里需要明白一点,数字的哈希值就等于他本身,所以才有不计算h的哈希值,就执行h | ht->nTableMask。

然后处理一下冲突,最后得出key为idx的数据是否存在于数组中。

如果idx确确实实是字符串,那么思路更简单一点,最后通过zen_hash_find_bucket来判断是否存在,与上面zend_hash_index_find_bucket不同的是,函数中要先计算字符串key的哈希值,然后再执行h | ht->nTableMask。

因为篇幅有限,就写到这里了,如果有不正确的地方烦请指正!












您可能感兴趣的与本文相关的镜像

Llama Factory

Llama Factory

模型微调
LLama-Factory

LLaMA Factory 是一个简单易用且高效的大型语言模型(Large Language Model)训练与微调平台。通过 LLaMA Factory,可以在无需编写任何代码的前提下,在本地完成上百种预训练模型的微调

import json import re import logging import sys from pathlib import Path from shutil import copy2 from argparse import ArgumentParser class CLMRangeSynchronizer: def __init__(self, manifest_path="output/generated_ranges_manifest.json", c_file_path="input/wlc_clm_data_6726b0.c", dry_run=False): self.manifest_path = Path(manifest_path) self.c_file_path = Path(c_file_path) self.dry_run = dry_run # 是否为预览模式 if not self.manifest_path.exists(): raise FileNotFoundError(f"找不到 manifest 文件: {self.manifest_path}") if not self.c_file_path.exists(): raise FileNotFoundError(f"找不到 C 源文件: {self.c_file_path}") self.used_ranges = [] self.array_macros = {} self.struct_entries = {} self.start_marker = "// === START: CHANNEL RANGES" self.end_marker = "// === END: CHANNEL RANGES" def load_manifest(self): """加载并解析 generated_ranges_manifest.json""" with open(self.manifest_path, 'r', encoding='utf-8') as f: data = json.load(f) if "used_ranges" not in data: raise KeyError("❌ manifest 文件缺少 'used_ranges' 字段") valid_ranges = [] for item in data["used_ranges"]: if isinstance(item, str) and re.match(r'^RANGE_[\w\d_]+_\d+_\d+$', item): valid_ranges.append(item) else: logger.warning(f"跳过无效项: {item}") self.used_ranges = sorted(set(valid_ranges)) logger.info(f"已加载 {len(self.used_ranges)} 个有效 RANGE 宏") def parse_c_arrays(self): """解析 C 文件中的 channel_ranges_xxx[] 数组""" content = self.c_file_path.read_text(encoding='utf-8') start_idx = content.find(self.start_marker) end_idx = content.find(self.end_marker) if start_idx == -1 or end_idx == -1: raise ValueError("未找到 CHANNEL RANGES 注释锚点") block = content[start_idx:end_idx] pattern = re.compile( r'static\s+const\s+struct\s+clm_channel_range\s+(channel_ranges_[\w\d]+)\s*\[\s*\]\s*=\s*\{([^}]*)\};', re.DOTALL ) for array_name, body in pattern.findall(block): macros = [m.strip() for m in re.findall(r'RANGE_[\w\d_]+', body)] entries = [] for low, high in re.findall(r'\{\s*(\d+)\s*,\s*(\d+)\s*\}', body): entries.append({"low": int(low), "high": int(high)}) self.array_macros[array_name] = macros self.struct_entries[array_name] = entries logger.debug(f"解析数组 {array_name}: {len(macros)} 个宏") def get_array_name_for_range(self, range_macro): """根据 RANGE 宏推断应属于哪个数组""" # 提取 band (如 2g, 5g) 和 bw (如 20, 40, 80) match = re.match(r'RANGE_([0-9]+[A-Za-z])_([0-9]+)M_', range_macro, re.IGNORECASE) if not match: logger.warning(f"无法推断数组名: {range_macro}") return None band = match.group(1).lower() # '2g' bw = match.group(2) # '20' return f"channel_ranges_{band}_{bw}m" def extract_channels_from_macro(self, macro): """ 从类似 RANGE_2G_20M_1_1 的宏中提取 low 和 high 信道号 支持任何形式:只要最后是 _数字_数字 就行 """ match = re.search(r'_(\d+)_(\d+)$', macro) if match: low = int(match.group(1)) high = int(match.group(2)) return low, high logger.warning(f"宏格式错误,无法提取信道: {macro}") return None, None def validate_and_repair(self): """确保每个 used_range 都在正确的数组中""" modified = False changes = [] for range_macro in self.used_ranges: array_name = self.get_array_name_for_range(range_macro) if not array_name: logger.warning(f"无法识别数组类型,跳过: {range_macro}") continue if array_name not in self.array_macros: logger.info(f"创建新数组: {array_name}") self.array_macros[array_name] = [] self.struct_entries[array_name] = [] if range_macro not in self.array_macros[array_name]: low, high = self.extract_channels_from_macro(range_macro) if low is not None and high is not None: self.array_macros[array_name].append(range_macro) self.struct_entries[array_name].append({"low": low, "high": high}) changes.append(f"添加: {range_macro} → {{{low}, {high}}} 到 {array_name}") logger.info(f"将添加: {range_macro} → {{{low}, {high}}} 到 {array_name}") modified = True else: logger.warning(f"无法解析信道范围: {range_macro}") if modified and not self.dry_run: self._write_back_in_block() logger.info("C 文件已更新") elif modified and self.dry_run: logger.info("DRY-RUN MODE: 有变更但不会写入文件") else: logger.info("所有 RANGE 已存在,无需修改") return modified def _infer_array_from_enum(self, enum_decl): match = re.search(r'enum\s+range_([a-z\d_]+)', enum_decl) if match: return f"channel_ranges_{match.group(1)}" return None def _format_array_body(self, macros, structs, indent=" "): """ 将宏和结构体信息格式化为字符串数组形式。 Args: macros (list): 宏名称列表。 structs (list): 包含每个结构体信息的字典列表,每个字典包含'low'和'high'键。 indent (str): 每行前的缩进字符串,默认为" "。 Returns: str: 格式化后的字符串数组。 """ lines = [] for i in range(len(macros)): line = f"{indent}{{ {structs[i]['low']}, {structs[i]['high']} }}, /* {macros[i]} */" lines.append(line) return "\n".join(lines) def _write_back_in_block(self): """安全地一次性更新所有数组和枚举""" content = self.c_file_path.read_text(encoding='utf-8') start_idx = content.find(self.start_marker) end_idx = content.find(self.end_marker) + len(self.end_marker) header = content[:start_idx] footer = content[end_idx:] block = content[start_idx:end_idx] new_block = block array_pattern = re.compile( r'(static\s+const\s+struct\s+clm_channel_range\s+(channel_ranges_[\w\d]+)\s*\[\s*\]\s*=\s*\{)([^}]*)\};', re.DOTALL ) for match in array_pattern.finditer(block): array_name = match.group(2) if array_name not in self.array_macros: continue macros = self.array_macros[array_name] structs = self.struct_entries[array_name] formatted_body = self._format_array_body(macros, structs) old_decl = match.group(0) new_decl = f"{match.group(1)}\n{formatted_body}\n}};" new_block = new_block.replace(old_decl, new_decl) enum_pattern = re.compile(r'(enum\s+range_[\w\d_]+\s*\{)([^}]*)\}(;)', re.DOTALL) final_block = new_block for match in enum_pattern.finditer(new_block): enum_body = match.group(2).strip() existing_macros = dict(re.findall(r'(RANGE_[\w\d_]+)\s*=\s*(\d+)', enum_body)) array_name = self._infer_array_from_enum(match.group(0)) if not array_name or array_name not in self.array_macros: continue expected_macros = self.array_macros[array_name] current_ids = [int(v) for v in existing_macros.values()] next_id = max(current_ids) + 1 if current_ids else 0 missing_macros = [m for m in expected_macros if m not in existing_macros] if not missing_macros: continue insert_lines = [f" {macro} = {next_id + i}," for i, macro in enumerate(missing_macros)] to_insert = "\n" + "\n".join(insert_lines) if enum_body.endswith(","): to_insert = "\n".join(insert_lines) final_block = final_block[:match.end(2)] + "," + to_insert + final_block[match.end(2):] else: final_block = final_block[:match.end(2)] + to_insert + final_block[match.end(2):] # ✅ 只有非 dry-run 才写文件 if self.dry_run: logger.info("DRY-RUN: 跳过写入文件") return backup_path = self.c_file_path.with_suffix('.c.bak') copy2(self.c_file_path, backup_path) logger.info(f"已备份 → {backup_path}") self.c_file_path.write_text(header + final_block + footer, encoding='utf-8') logger.info(f"已保存: {self.c_file_path}") def run(self): logger.info("开始同步 CLM RANGE 定义...") try: self.load_manifest() self.parse_c_arrays() was_modified = self.validate_and_repair() logger.info("同步完成" + ("(有修改)" if was_modified else "(无变更)")) except Exception as e: logger.error(f"同步失败: {e}") raise # ======================== # 命令行入口 # ======================== def main(): # ✅ 所有 logging 配置都在这里完成 logging.basicConfig( level=logging.INFO, format='%(asctime)s [%(levelname)s] %(name)s: %(message)s', force=True # 👈 每次运行都强制重置 logging ) logger = logging.getLogger(__name__) # === 固定配置(你可以根据需要修改这些路径)=== manifest_path = "output/generated_ranges_manifest.json" c_file_path = "input/wlc_clm_data_6726b0.c" dry_run = False # 设为 True 可预览变更而不写入文件 log_level = "INFO" # 可选: "DEBUG", "INFO", "WARNING", "ERROR" # 设置日志级别 logging.getLogger().setLevel(log_level) print(f"📌 开始同步 RANGE 定义...") print(f"📊 manifest 文件: {manifest_path}") print(f"📄 C 源文件: {c_file_path}") if dry_run: print("🟡 启用 dry-run 模式:仅预览变更,不修改文件") try: sync = CLMRangeSynchronizer( manifest_path=manifest_path, c_file_path=c_file_path, dry_run=dry_run ) sync.run() print("✅ 同步完成!") except FileNotFoundError as e: logger.error(f"文件未找到: {e}") print("❌ 请检查文件路径是否正确。") sys.exit(1) except PermissionError as e: logger.error(f"权限错误: {e}") print("❌ 无法读取或写入文件,请检查权限。") sys.exit(1) except Exception as e: logger.error(f"程序异常退出: {e}", exc_info=True) sys.exit(1) if __name__ == '__main__': main() 欧克,在源码基础上更改
最新发布
10-16
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值