Flask源码解析:字符串方法endswith与os.path.splitext()

本文介绍了Python中使用endswith方法判断文件扩展名的技巧,以及如何利用os模块的split和splitext函数来分离文件路径和文件名,适用于进行文件管理和处理的开发者。

1、字符串方法endswith

  endswith方法:

# NOTE(review): this is the IDE-generated stub of str.endswith, restored from its
# __doc__ — the real implementation lives in C, so the `return False` placeholder
# body below is never actually executed.
def endswith(self, suffix, start=None, end=None): # real signature unknown; restored from __doc__
	"""
	S.endswith(suffix[, start[, end]]) -> bool
	
	Return True if S ends with the specified suffix, False otherwise.
	With optional start, test S beginning at that position.
	With optional end, stop comparing S at that position.
	suffix can also be a tuple of strings to try.
	"""
	return False

  其中suffix支持字符串构成的元组(tuple)。

# Demo: str.endswith accepts a tuple of candidate suffixes and returns True
# if the string ends with any of them.
filename = "1.html"
filename = "1.xml"  # rebinding — only the last value is actually tested
suffixes = ('.html', '.htm', '.xml', '.xhtml')
r = filename.endswith(suffixes)
print(r)

  输出结果为:True(filename 最终被重新赋值为 "1.xml",以元组中的 '.xml' 结尾)

2、os.path.splitext

# Demo: separating a file's base name, extension, and directory with os.path.
import os

fn = "F:/2018/python/review/project/Flask_01/__get__与__set__.py"

# os.path.basename strips the directory part, keeping only the file name.
f = os.path.basename(fn)
print(f)

# os.path.splitext splits the file name into (stem, extension).
r = os.path.splitext(f)
print(r)

# os.path.split returns (directory, file name).
dirname, filename = os.path.split('/home/ubuntu/python_coding/split_func/split_function.py')
print(dirname, filename)

  os.path.splitext():将文件名和扩展名分开,返回由文件名和扩展名构成的元组

  os.path.split():返回文件的路径和文件名构成的元组

  

  

转载于:https://www.cnblogs.com/bad-robot/p/10062284.html

# Flask service that mirrors NN-model profiling JSON files from an S3/OBS store
# and serves extracted metrics (.om latency, mean_ddr) over an HTTP API.
#
# NOTE(review): reformatted from a whitespace-mangled paste; statement grouping
# inside send_request's retry loop was reconstructed from context — confirm
# against the original file.
# Fixes applied: app.run now honors --host/--port; ZeroDivisionError guard in
# the report; None-metrics guard in the print loop; request.json None guard.
import os
import json
import requests
from flask import Flask, jsonify, request, render_template
from datetime import datetime
import logging
import glob
import time
import argparse
import re

app = Flask(__name__)

# Logging setup.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Base paths.
BASE_PATH = r"C:\Users\l30078648\Desktop\250730"
DOWNLOAD_PATH = os.path.join(BASE_PATH, "downloads")

# S3 OBS API configuration.
S3_OBS_API_URL = "https://adsci.api.ias.huawei.com/icollect/get_s3_obs_file_download_url/"
S3_OBS_HEADERS = {
    "Accept-Language": "zh-CN",
    "Content-Type": "application/json"
}
USER_ID = "h00512356"

# Chip model -> file-name token mapping (BS9SX1A needs an extra trailing 'A').
CHIP_FILENAME_MAP = {
    "BS9SX1A": "BS9SX1AA",
    "Ascend610Lite": "Ascend610Lite"
}

_CHIPS = ("Ascend610Lite", "BS9SX1A")
_GROUPS = ("rl_nn", "rsc_nn", "prediction_nn")


def _chip_group_paths(root):
    """Return the {chip: {group: path}} mapping rooted at *root*."""
    return {chip: {group: os.path.join(root, chip, group) for group in _GROUPS}
            for chip in _CHIPS}


# Build-id "track" files and downloaded profiling JSONs, per chip/model-group.
TRACK_PATHS = _chip_group_paths(os.path.join(BASE_PATH, ".track"))
JSON_PATHS = _chip_group_paths(DOWNLOAD_PATH)


def send_request(method, url, headers=None, data=None, timeout=30, max_try_times=1, **kwargs):
    """Send an HTTP request with a simple retry loop.

    Returns the last requests.Response obtained (which may be non-2xx), or
    None if every attempt raised before a response was received.
    """
    try_times = 0
    response = None
    if data is None:
        data = {}
    if 'files' in kwargs:
        # multipart upload: let requests encode the payload itself
        request_body = data
    else:
        request_body = json.dumps(data) if not isinstance(data, str) else data.encode("utf-8")
    while try_times < max_try_times:
        try:
            response = requests.request(
                method=method, url=url, headers=headers, data=request_body,
                timeout=timeout, verify=False, **kwargs  # NOTE(review): TLS verification disabled
            )
            if response.status_code in (200, 201):
                break
        except requests.exceptions.Timeout:
            logger.warning(f"请求超时: url={url}")
        except Exception as e:
            logger.exception(f"请求失败: url={url}, 错误: {str(e)}")
            if response:
                logger.error(f"响应内容: {response.text}")
        try_times += 1
        logger.warning(f"请求失败, url={url}, 重试中 ({try_times}/{max_try_times})")
        if max_try_times > 1:
            time.sleep(5)
    return response


def extract_build_ids_from_track(chip, group):
    """Scan the .track folder for *chip*/*group* and collect build_ids.

    Returns (build_data, build_missing) where build_data maps model name ->
    build_id (possibly None) and build_missing lists models without one.
    """
    build_data = {}
    build_missing = []
    if chip not in TRACK_PATHS or group not in TRACK_PATHS[chip]:
        logger.error(f"无效路径: {chip}/{group}")
        return build_data, build_missing
    group_path = TRACK_PATHS[chip][group]
    if not os.path.exists(group_path):
        logger.error(f"原始路径不存在: {group_path}")
        return build_data, build_missing
    json_files = glob.glob(os.path.join(group_path, "*.json"))
    logger.info(f"在路径 {group_path} 中找到 {len(json_files)} 个JSON文件")
    for json_file in json_files:
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            model_name = os.path.splitext(os.path.basename(json_file))[0]
            build_id = data.get('build_id')
            # normalize numeric ids to strings
            if isinstance(build_id, int):
                build_id = str(build_id)
            build_data[model_name] = build_id
            if build_id:
                logger.debug(f"提取成功: {model_name} -> build_id: {build_id}")
            else:
                logger.warning(f"文件 {json_file} 中没有找到build_id字段")
                build_missing.append(model_name)
        except Exception as e:
            logger.error(f"解析原始文件 {json_file} 时出错: {str(e)}")
            build_missing.append(os.path.basename(json_file))
            continue
    # NOTE(review): parse failures are counted in build_missing but not in
    # build_data, so this count can undershoot; kept for log compatibility.
    logger.info(f"成功提取 {len(build_data) - len(build_missing)} 个build_id")
    return build_data, build_missing


def get_chip_filename(chip):
    """Return the chip token used in remote file names (mapped overrides first)."""
    return CHIP_FILENAME_MAP.get(chip, chip)  # f"{chip}" was a pointless copy


def download_file_with_build_id(model_name, chip, group, build_id):
    """Download one profiling JSON identified by *build_id*.

    Returns (save_path, parsed_json); either may be None on failure.
    """
    if not build_id:
        logger.warning(f"跳过缺少build_id的文件: {model_name}")
        return None, None
    chip_filename = get_chip_filename(chip)
    # Remote layout: <build_id>_PREBUILD_OUTPUTS/profiling_summary/<chip>_prof_record.json
    s3_obs_url = f"hpct/nn_model_prebuild/{build_id}_PREBUILD_OUTPUTS/profiling_summary/"
    file_name = f"{chip_filename}_prof_record.json"
    download_url, error = get_download_url(build_id, s3_obs_url, file_name)
    if not download_url:
        logger.error(f"获取下载URL失败: {error}")
        return None, None
    save_path, file_data = download_and_save(download_url, model_name, chip, group, file_name)
    return save_path, file_data


def get_download_url(build_id, s3_obs_url, file_name):
    """Ask the S3 OBS API for a presigned download URL.

    Returns (url, None) on success or (None, error_message) on failure.
    """
    payload = {
        "user_id": USER_ID,
        "s3_obs_url": s3_obs_url,
        "file_name": file_name
    }
    logger.info(f"请求下载URL: {payload}")
    response = send_request("POST", S3_OBS_API_URL, headers=S3_OBS_HEADERS, data=payload)
    if not response or response.status_code != 200:
        logger.error(f"获取下载URL失败: 状态码={getattr(response, 'status_code', '无响应')}")
        return None, "下载URL获取失败"
    try:
        result = response.json()
        download_url = result.get("download_url")
        if not download_url:
            logger.error(f"响应中缺少download_url字段: {result}")
            return None, "响应中缺少download_url字段"
        # The API sometimes returns the wrong port; force 10001.
        fixed_url = fix_download_url_port(download_url)
        return fixed_url, None
    except Exception as e:
        logger.error(f"解析响应失败: {str(e)}")
        return None, f"解析响应失败: {str(e)}"


def fix_download_url_port(url):
    """Rewrite any *.huawei.com:<port> in *url* to port 10001."""
    pattern = r"(\w+\.huawei\.com):\d+"
    replacement = r"\1:10001"
    fixed_url = re.sub(pattern, replacement, url)
    if fixed_url != url:
        logger.info(f"URL端口已修正: {url} -> {fixed_url}")
    return fixed_url


def download_and_save(download_url, model_name, chip, group, file_name):
    """Fetch *download_url* and store it under JSON_PATHS[chip][group].

    Returns (save_path, parsed_json); parsed_json is None if the saved file
    is not valid JSON, and both are None if the download/save failed.
    """
    logger.info(f"下载文件: {download_url}")
    file_response = send_request("GET", download_url, timeout=60)
    if not file_response or file_response.status_code != 200:
        logger.error(f"文件下载失败: 状态码={getattr(file_response, 'status_code', '无响应')}")
        return None, None
    save_dir = JSON_PATHS[chip][group]
    os.makedirs(save_dir, exist_ok=True)
    save_file_name = f"{model_name}_{file_name}"
    save_path = os.path.join(save_dir, save_file_name)
    try:
        with open(save_path, 'wb') as f:
            f.write(file_response.content)
        logger.info(f"文件保存成功: {save_path}")
        try:
            with open(save_path, 'r', encoding='utf-8') as f:
                file_data = json.load(f)
            return save_path, file_data
        except Exception as e:
            logger.error(f"解析下载的JSON文件失败: {save_path}, 错误: {str(e)}")
            return save_path, None
    except Exception as e:
        logger.error(f"保存文件失败: {save_path}, 错误: {str(e)}")
        return None, None


def download_files_with_build_ids(build_id_map, chip, group):
    """Download every file in *build_id_map*; returns (statuses, errors)."""
    downloaded_files = []
    download_errors = []
    for model_name, build_id in build_id_map.items():
        if not build_id:
            logger.warning(f"跳过缺少build_id的文件: {model_name}")
            downloaded_files.append({
                "model": model_name, "status": "skipped", "reason": "missing_build_id"
            })
            continue
        logger.info(f"下载文件: {model_name} (build_id={build_id})")
        save_path, file_data = download_file_with_build_id(model_name, chip, group, build_id)
        if save_path:
            if file_data:
                status = "success"
            else:
                # saved but unparseable
                status = "partial_success"
                download_errors.append({"model": model_name, "error": "文件解析失败"})
            downloaded_files.append({"model": model_name, "status": status})
        else:
            logger.warning(f"下载失败: {model_name}")
            download_errors.append({"model": model_name, "error": "下载失败"})
            downloaded_files.append({"model": model_name, "status": "failed"})
    return downloaded_files, download_errors


def extract_om_and_mean_ddr(data):
    """Extract the first *.om key's value and mean_ddr from a profiling dict.

    :param data: parsed profiling JSON dict
    :return: (om_value, mean_ddr_value), either may be None
    """
    om_value = None
    mean_ddr = None
    om_keys = [key for key in data.keys() if key.endswith('.om')]
    if om_keys:
        om_value = data[om_keys[0]]
    if 'mean_ddr' in data:
        mean_ddr = data['mean_ddr']
    else:
        # fall back to one level of nesting
        for value in data.values():
            if isinstance(value, dict) and 'mean_ddr' in value:
                mean_ddr = value['mean_ddr']
                break
    return om_value, mean_ddr


def parse_json_file(file_path, file_data=None):
    """Parse one downloaded profiling JSON into a metrics dict (None on error)."""
    if file_data is None:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                file_data = json.load(f)
        except Exception as e:
            logger.error(f"读取文件出错 {file_path}: {e}")
            return None
    file_name = os.path.basename(file_path)
    # NOTE(review): this truncates model names that contain '_' — the saved
    # file is "<model>_<chip>_prof_record.json"; kept for compatibility.
    model_name = file_name.split('_')[0]
    om_value, mean_ddr = extract_om_and_mean_ddr(file_data)
    file_size = os.path.getsize(file_path)
    last_modified = datetime.fromtimestamp(os.path.getmtime(file_path)).isoformat()
    return {
        "model_name": model_name,
        "file_name": file_name,
        "file_path": file_path,
        "file_size": file_size,
        "last_modified": last_modified,
        "om_value": om_value,
        "mean_ddr": mean_ddr
    }


def print_all_om_and_mean_ddr():
    """Print every device/group's .om and mean_ddr values to the terminal."""
    print("\n" + "=" * 80)
    print("JSON性能数据提取报告 - .om值和mean_ddr")
    print("=" * 80)
    devices = [d for d in os.listdir(DOWNLOAD_PATH)
               if os.path.isdir(os.path.join(DOWNLOAD_PATH, d))]
    if not devices:
        print(f"在 {DOWNLOAD_PATH} 中未找到设备目录")
        return
    total_files = 0
    valid_files = 0
    for device in devices:
        device_path = os.path.join(DOWNLOAD_PATH, device)
        groups = [g for g in os.listdir(device_path)
                  if os.path.isdir(os.path.join(device_path, g))]
        if not groups:
            print(f"\n设备 {device} 中没有模型组")
            continue
        print(f"\n设备: {device} ({len(groups)}个模型组)")
        print("-" * 70)
        for group in groups:
            group_path = os.path.join(device_path, group)
            json_files = glob.glob(os.path.join(group_path, "*.json"))
            if not json_files:
                print(f" ├── 组 {group}: 没有JSON文件")
                continue
            print(f" ├── 组 {group}: {len(json_files)}个模型")
            for json_file in json_files:
                metrics = parse_json_file(json_file)
                if metrics is None:  # fix: unreadable files previously crashed here
                    continue
                filename = os.path.basename(json_file)
                om_str = f"{metrics['om_value']:.4f}" if metrics['om_value'] is not None else "N/A"
                ddr_str = f"{metrics['mean_ddr']:.4f}" if metrics['mean_ddr'] is not None else "N/A"
                print(f" │ ├── {filename[:30]:<30} | .om值: {om_str:<10} | mean_ddr: {ddr_str:<10}")
                total_files += 1
                if metrics['om_value'] is not None and metrics['mean_ddr'] is not None:
                    valid_files += 1
    print("\n" + "=" * 80)
    if total_files:  # fix: avoid ZeroDivisionError when no files were found
        print(f"扫描完成: 共处理 {total_files} 个文件, 有效数据 {valid_files} 个 ({valid_files/total_files*100:.1f}%)")
    else:
        print("扫描完成: 共处理 0 个文件")
    print("=" * 80 + "\n")


def get_performance_data(chip, group, refresh_data=False):
    """Assemble the /api/performance payload; *refresh_data* triggers downloads."""
    performance_data = {
        "status": "success",
        "models": [],
        "timestamp": datetime.now().isoformat(),
        "chip_type": chip,
        "group": group,
        "json_path": JSON_PATHS[chip].get(group, "") if chip in JSON_PATHS else "",
        "track_path": TRACK_PATHS[chip].get(group, "") if chip in TRACK_PATHS else "",
        "file_count": 0,
        "build_id_count": 0,
        "refresh_performed": refresh_data,
        "download_errors": [],
        "downloaded_files": [],
        "build_missing": []
    }
    # 1. collect build_ids from the .track folder
    build_id_map, build_missing = extract_build_ids_from_track(chip, group)
    performance_data["build_id_count"] = len(build_id_map) - len(build_missing)
    performance_data["build_missing"] = build_missing
    # 2. validate the requested chip/group
    if chip not in JSON_PATHS or group not in JSON_PATHS[chip]:
        performance_data["status"] = "error"
        performance_data["error"] = "无效芯片或模型组"
        return performance_data
    group_path = JSON_PATHS[chip][group]
    if not os.path.exists(group_path):
        performance_data["status"] = "error"
        performance_data["error"] = "JSON路径不存在"
        return performance_data
    os.makedirs(group_path, exist_ok=True)
    # 3. optionally refresh local files
    if refresh_data:
        logger.info(f"开始刷新下载,共有{len(build_id_map)}个文件需要处理")
        downloaded_files, download_errors = download_files_with_build_ids(build_id_map, chip, group)
        performance_data["downloaded_files"] = downloaded_files
        performance_data["download_errors"] = download_errors
        performance_data["refresh_status"] = "executed"
    else:
        performance_data["refresh_status"] = "skipped"
        logger.info("跳过文件下载,仅使用本地现有数据")
    # 4. parse all local JSON files into metrics
    json_files = glob.glob(os.path.join(group_path, "*.json"))
    performance_data["file_count"] = len(json_files)
    for json_file in json_files:
        model_data = parse_json_file(json_file)
        if model_data:
            model_name = model_data["model_name"]
            model_data["build_id"] = build_id_map.get(model_name, "NA")
            performance_data["models"].append(model_data)
    return performance_data


@app.route('/api/refresh', methods=['POST'])
def refresh_data_api():
    """Manually trigger a re-download of all files for a chip/group."""
    start_time = time.time()
    try:
        data = request.json or {}  # fix: request.json is None without a JSON body
        chip = data.get('chip', 'Ascend610Lite')
        group = data.get('group', 'rl_nn')
        logger.info(f"手动刷新请求 - 芯片: {chip}, 组: {group}")
        build_id_map, build_missing = extract_build_ids_from_track(chip, group)
        downloaded_files, download_errors = download_files_with_build_ids(build_id_map, chip, group)
        response = {
            "status": "success",
            "chip": chip,
            "group": group,
            "downloaded_files": downloaded_files,
            "download_errors": download_errors,
            "build_missing": build_missing,
            "process_time": round(time.time() - start_time, 4)
        }
        return jsonify(response)
    except Exception as e:
        logger.exception("刷新数据时出错")
        return jsonify({
            "status": "error",
            "error": "服务器内部错误",
            "details": str(e),
            "process_time": round(time.time() - start_time, 4)
        }), 500


@app.route('/api/performance', methods=['GET'])
def performance_api():
    """Performance-data endpoint; ?refresh=true forces a download pass."""
    start_time = time.time()
    try:
        device = request.args.get('device', 'Ascend610Lite')
        group = request.args.get('type', 'rl_nn')
        refresh = request.args.get('refresh', 'false').lower() == 'true'
        logger.info(f"性能API请求 - 设备: {device}, 组: {group}, 刷新: {refresh}")
        performance_data = get_performance_data(device, group, refresh_data=refresh)
        process_time = time.time() - start_time
        response = jsonify({**performance_data, "process_time": round(process_time, 4)})
        response.headers['Cache-Control'] = 'public, max-age=300'
        return response
    except Exception as e:
        logger.exception("处理请求时出错")
        return jsonify({
            "status": "error",
            "error": "服务器内部错误",
            "details": str(e),
            "process_time": round(time.time() - start_time, 4)
        }), 500


@app.route('/')
def home():
    """Serve the front-end page."""
    return render_template('index.html')


if __name__ == '__main__':
    # ensure the download tree exists
    os.makedirs(DOWNLOAD_PATH, exist_ok=True)
    for chip, groups in JSON_PATHS.items():
        for group, path in groups.items():
            os.makedirs(path, exist_ok=True)
    parser = argparse.ArgumentParser(description='AI芯片性能监控服务')
    parser.add_argument('--host', type=str, default='127.0.0.1', help='服务监听地址')
    parser.add_argument('--port', type=int, default=8080, help='服务监听端口')
    parser.add_argument('--print', action='store_true', help='启动时打印所有.om和mean_ddr值到终端')
    args = parser.parse_args()
    if args.print:
        print_all_om_and_mean_ddr()
    # fix: honor --host/--port (previously hard-coded to 127.0.0.1:8080)
    app.run(host=args.host, port=args.port, debug=True)

# (原帖求助文字,保留为注释: 将提取的build_id和模型全名返回到前端)
08-12
# Flask service (v2 of the script above in this thread): additionally keeps the
# full model name (with suffix), computes a bandwidth metric, and returns a
# trimmed per-model payload to the front end.
#
# NOTE(review): reformatted from a whitespace-mangled paste; retry-loop
# statement grouping reconstructed from context — confirm against the original.
# Fixes applied: app.run now honors --host/--port; ZeroDivisionError guard in
# the report; None-metrics guard in the print loop; request.json None guard.
import os
import json
import requests
from flask import Flask, jsonify, request, render_template
from datetime import datetime
import logging
import glob
import time
import argparse
import re

app = Flask(__name__)

# Logging setup.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Base paths.
BASE_PATH = r"C:\Users\l30078648\Desktop\250730"
DOWNLOAD_PATH = os.path.join(BASE_PATH, "downloads")

# S3 OBS API configuration.
S3_OBS_API_URL = "https://adsci.api.ias.huawei.com/icollect/get_s3_obs_file_download_url/"
S3_OBS_HEADERS = {
    "Accept-Language": "zh-CN",
    "Content-Type": "application/json"
}
USER_ID = "h00512356"

# Chip model -> file-name token mapping (BS9SX1A needs an extra trailing 'A').
CHIP_FILENAME_MAP = {
    "BS9SX1A": "BS9SX1AA",
    "Ascend610Lite": "Ascend610Lite"
}

_CHIPS = ("Ascend610Lite", "BS9SX1A")
_GROUPS = ("rl_nn", "rsc_nn", "prediction_nn")


def _chip_group_paths(root):
    """Return the {chip: {group: path}} mapping rooted at *root*."""
    return {chip: {group: os.path.join(root, chip, group) for group in _GROUPS}
            for chip in _CHIPS}


# Build-id "track" files and downloaded profiling JSONs, per chip/model-group.
TRACK_PATHS = _chip_group_paths(os.path.join(BASE_PATH, ".track"))
JSON_PATHS = _chip_group_paths(DOWNLOAD_PATH)


def send_request(method, url, headers=None, data=None, timeout=30, max_try_times=1, **kwargs):
    """Send an HTTP request with a simple retry loop.

    Returns the last requests.Response obtained (which may be non-2xx), or
    None if every attempt raised before a response was received.
    """
    try_times = 0
    response = None
    if data is None:
        data = {}
    if 'files' in kwargs:
        # multipart upload: let requests encode the payload itself
        request_body = data
    else:
        request_body = json.dumps(data) if not isinstance(data, str) else data.encode("utf-8")
    while try_times < max_try_times:
        try:
            response = requests.request(
                method=method, url=url, headers=headers, data=request_body,
                timeout=timeout, verify=False, **kwargs  # NOTE(review): TLS verification disabled
            )
            if response.status_code in (200, 201):
                break
        except requests.exceptions.Timeout:
            logger.warning(f"请求超时: url={url}")
        except Exception as e:
            logger.exception(f"请求失败: url={url}, 错误: {str(e)}")
            if response:
                logger.error(f"响应内容: {response.text}")
        try_times += 1
        logger.warning(f"请求失败, url={url}, 重试中 ({try_times}/{max_try_times})")
        if max_try_times > 1:
            time.sleep(5)
    return response


def extract_build_ids_from_track(chip, group):
    """Scan the .track folder and collect build_id + full model name per file.

    Returns (build_data, build_missing); build_data maps full model name ->
    {"build_id": ..., "model_full_name": ...}.
    """
    build_data = {}
    build_missing = []
    if chip not in TRACK_PATHS or group not in TRACK_PATHS[chip]:
        logger.error(f"无效路径: {chip}/{group}")
        return build_data, build_missing
    group_path = TRACK_PATHS[chip][group]
    if not os.path.exists(group_path):
        logger.error(f"原始路径不存在: {group_path}")
        return build_data, build_missing
    json_files = glob.glob(os.path.join(group_path, "*.json"))
    logger.info(f"在路径 {group_path} 中找到 {len(json_files)} 个JSON文件")
    for json_file in json_files:
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # full model name = track file name without extension
            model_full_name = os.path.splitext(os.path.basename(json_file))[0]
            build_id = data.get('build_id')
            if isinstance(build_id, int):
                build_id = str(build_id)
            build_data[model_full_name] = {
                "build_id": build_id,
                "model_full_name": model_full_name
            }
            if build_id:
                logger.debug(f"提取成功: {model_full_name} -> build_id: {build_id}")
            else:
                logger.warning(f"文件 {json_file} 中没有找到build_id字段")
                build_missing.append(model_full_name)
        except Exception as e:
            logger.error(f"解析原始文件 {json_file} 时出错: {str(e)}")
            build_missing.append(os.path.basename(json_file))
            continue
    # NOTE(review): parse failures inflate build_missing without entries in
    # build_data, so this count can undershoot; kept for log compatibility.
    logger.info(f"成功提取 {len(build_data) - len(build_missing)} 个build_id")
    return build_data, build_missing


def get_chip_filename(chip):
    """Return the chip token used in remote file names (mapped overrides first)."""
    return CHIP_FILENAME_MAP.get(chip, chip)  # f"{chip}" was a pointless copy


def download_file_with_build_id(model_full_name, chip, group, build_id):
    """Download one profiling JSON identified by *build_id*.

    Returns (save_path, parsed_json); either may be None on failure.
    """
    if not build_id:
        logger.warning(f"跳过缺少build_id的文件: {model_full_name}")
        return None, None
    chip_filename = get_chip_filename(chip)
    s3_obs_url = f"hpct/nn_model_prebuild/{build_id}_PREBUILD_OUTPUTS/profiling_summary/"
    file_name = f"{chip_filename}_prof_record.json"
    download_url, error = get_download_url(build_id, s3_obs_url, file_name)
    if not download_url:
        logger.error(f"获取下载URL失败: {error}")
        return None, None
    save_path, file_data = download_and_save(download_url, model_full_name, chip, group, file_name)
    return save_path, file_data


def get_download_url(build_id, s3_obs_url, file_name):
    """Ask the S3 OBS API for a presigned download URL.

    Returns (url, None) on success or (None, error_message) on failure.
    """
    payload = {
        "user_id": USER_ID,
        "s3_obs_url": s3_obs_url,
        "file_name": file_name
    }
    logger.info(f"请求下载URL: {payload}")
    response = send_request("POST", S3_OBS_API_URL, headers=S3_OBS_HEADERS, data=payload)
    if not response or response.status_code != 200:
        logger.error(f"获取下载URL失败: 状态码={getattr(response, 'status_code', '无响应')}")
        return None, "下载URL获取失败"
    try:
        result = response.json()
        download_url = result.get("download_url")
        if not download_url:
            logger.error(f"响应中缺少download_url字段: {result}")
            return None, "响应中缺少download_url字段"
        # The API sometimes returns the wrong port; force 10001.
        fixed_url = fix_download_url_port(download_url)
        return fixed_url, None
    except Exception as e:
        logger.error(f"解析响应失败: {str(e)}")
        return None, f"解析响应失败: {str(e)}"


def fix_download_url_port(url):
    """Rewrite any *.huawei.com:<port> in *url* to port 10001."""
    pattern = r"(\w+\.huawei\.com):\d+"
    replacement = r"\1:10001"
    fixed_url = re.sub(pattern, replacement, url)
    if fixed_url != url:
        logger.info(f"URL端口已修正: {url} -> {fixed_url}")
    return fixed_url


def download_and_save(download_url, model_full_name, chip, group, file_name):
    """Fetch *download_url* and store it under JSON_PATHS[chip][group].

    Returns (save_path, parsed_json); parsed_json is None if the saved file
    is not valid JSON, and both are None if the download/save failed.
    """
    logger.info(f"下载文件: {download_url}")
    file_response = send_request("GET", download_url, timeout=60)
    if not file_response or file_response.status_code != 200:
        logger.error(f"文件下载失败: 状态码={getattr(file_response, 'status_code', '无响应')}")
        return None, None
    save_dir = JSON_PATHS[chip][group]
    os.makedirs(save_dir, exist_ok=True)
    save_file_name = f"{model_full_name}_{file_name}"
    save_path = os.path.join(save_dir, save_file_name)
    try:
        with open(save_path, 'wb') as f:
            f.write(file_response.content)
        logger.info(f"文件保存成功: {save_path}")
        try:
            with open(save_path, 'r', encoding='utf-8') as f:
                file_data = json.load(f)
            return save_path, file_data
        except Exception as e:
            logger.error(f"解析下载的JSON文件失败: {save_path}, 错误: {str(e)}")
            return save_path, None
    except Exception as e:
        logger.error(f"保存文件失败: {save_path}, 错误: {str(e)}")
        return None, None


def download_files_with_build_ids(build_id_map, chip, group):
    """Download every file in *build_id_map*; returns (statuses, errors)."""
    downloaded_files = []
    download_errors = []
    for model_full_name, build_info in build_id_map.items():
        build_id = build_info["build_id"]
        if not build_id:
            logger.warning(f"跳过缺少build_id的文件: {model_full_name}")
            downloaded_files.append({
                "model": model_full_name, "status": "skipped", "reason": "missing_build_id"
            })
            continue
        logger.info(f"下载文件: {model_full_name} (build_id={build_id})")
        save_path, file_data = download_file_with_build_id(model_full_name, chip, group, build_id)
        if save_path:
            if file_data:
                status = "success"
            else:
                # saved but unparseable
                status = "partial_success"
                download_errors.append({"model": model_full_name, "error": "文件解析失败"})
            downloaded_files.append({"model": model_full_name, "status": status})
        else:
            logger.warning(f"下载失败: {model_full_name}")
            download_errors.append({"model": model_full_name, "error": "下载失败"})
            downloaded_files.append({"model": model_full_name, "status": "failed"})
    return downloaded_files, download_errors


def extract_om_and_mean_ddr(data):
    """Extract the first *.om value, mean_ddr, and a derived bandwidth.

    :param data: parsed profiling JSON dict
    :return: (om_value, mean_ddr_value, bandwidth_value), each possibly None
    """
    om_value = None
    mean_ddr = None
    bandwidth = None
    om_keys = [key for key in data.keys() if key.endswith('.om')]
    if om_keys:
        om_value = data[om_keys[0]]
    if 'mean_ddr' in data:
        mean_ddr = data['mean_ddr']
    else:
        # fall back to one level of nesting
        for value in data.values():
            if isinstance(value, dict) and 'mean_ddr' in value:
                mean_ddr = value['mean_ddr']
                break
    # bandwidth = latency * mean_ddr * 10 / 1000
    if om_value is not None and mean_ddr is not None:
        try:
            bandwidth = (float(om_value) * float(mean_ddr) * 10) / 1000
            bandwidth = round(bandwidth, 4)
        except (TypeError, ValueError):
            logger.error(f"带宽计算失败: om_value={om_value}, mean_ddr={mean_ddr}")
    return om_value, mean_ddr, bandwidth


def parse_json_file(file_path, file_data=None):
    """Parse one downloaded profiling JSON into a metrics dict (None on error)."""
    if file_data is None:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                file_data = json.load(f)
        except Exception as e:
            logger.error(f"读取文件出错 {file_path}: {e}")
            return None
    file_name = os.path.basename(file_path)
    # File layout: "<model_full_name>_<chip>_prof_record.json" — dropping the
    # last three '_' parts keeps underscores inside the model name intact.
    model_full_name = "_".join(file_name.split('_')[:-3])
    om_value, mean_ddr, bandwidth = extract_om_and_mean_ddr(file_data)
    file_size = os.path.getsize(file_path)
    last_modified = datetime.fromtimestamp(os.path.getmtime(file_path)).isoformat()
    return {
        "model_full_name": model_full_name,
        "file_name": file_name,
        "file_path": file_path,
        "file_size": file_size,
        "last_modified": last_modified,
        "om_value": om_value,
        "mean_ddr": mean_ddr,
        "bandwidth": bandwidth
    }


def print_all_om_and_mean_ddr():
    """Print every device/group's latency, DDR, and bandwidth to the terminal."""
    print("\n" + "=" * 100)
    print("JSON性能数据提取报告 - 时延(ms)、DDR和带宽(Mbps)")
    print("=" * 100)
    devices = [d for d in os.listdir(DOWNLOAD_PATH)
               if os.path.isdir(os.path.join(DOWNLOAD_PATH, d))]
    if not devices:
        print(f"在 {DOWNLOAD_PATH} 中未找到设备目录")
        return
    total_files = 0
    valid_files = 0
    for device in devices:
        device_path = os.path.join(DOWNLOAD_PATH, device)
        groups = [g for g in os.listdir(device_path)
                  if os.path.isdir(os.path.join(device_path, g))]
        if not groups:
            print(f"\n设备 {device} 中没有模型组")
            continue
        print(f"\n设备: {device} ({len(groups)}个模型组)")
        print("-" * 90)
        for group in groups:
            group_path = os.path.join(device_path, group)
            json_files = glob.glob(os.path.join(group_path, "*.json"))
            if not json_files:
                print(f" ├── 组 {group}: 没有JSON文件")
                continue
            print(f" ├── 组 {group}: {len(json_files)}个模型")
            for json_file in json_files:
                metrics = parse_json_file(json_file)
                if metrics is None:  # fix: unreadable files previously crashed here
                    continue
                om_str = f"{metrics['om_value']:.4f}" if metrics['om_value'] is not None else "N/A"
                ddr_str = f"{metrics['mean_ddr']:.4f}" if metrics['mean_ddr'] is not None else "N/A"
                bw_str = f"{metrics['bandwidth']:.4f}" if metrics['bandwidth'] is not None else "N/A"
                print(f" │ ├── {metrics['model_full_name']} | 时延: {om_str:<8} | DDR: {ddr_str:<8} | 带宽: {bw_str:<8}")
                total_files += 1
                if metrics['om_value'] is not None and metrics['mean_ddr'] is not None:
                    valid_files += 1
    print("\n" + "=" * 100)
    if total_files:  # fix: avoid ZeroDivisionError when no files were found
        print(f"扫描完成: 共处理 {total_files} 个文件, 有效数据 {valid_files} 个 ({valid_files/total_files*100:.1f}%)")
    else:
        print("扫描完成: 共处理 0 个文件")
    print("=" * 100 + "\n")


def get_performance_data(chip, group, refresh_data=False):
    """Assemble the /api/performance payload; *refresh_data* triggers downloads."""
    performance_data = {
        "status": "success",
        "models": [],
        "timestamp": datetime.now().isoformat(),
        "chip_type": chip,
        "group": group,
        "json_path": JSON_PATHS[chip].get(group, "") if chip in JSON_PATHS else "",
        "track_path": TRACK_PATHS[chip].get(group, "") if chip in TRACK_PATHS else "",
        "file_count": 0,
        "build_id_count": 0,
        "refresh_performed": refresh_data,
        "download_errors": [],
        "downloaded_files": [],
        "build_missing": []
    }
    # 1. collect build_ids and full model names from the .track folder
    build_id_map, build_missing = extract_build_ids_from_track(chip, group)
    performance_data["build_id_count"] = len(build_id_map) - len(build_missing)
    performance_data["build_missing"] = build_missing
    # 2. validate the requested chip/group
    if chip not in JSON_PATHS or group not in JSON_PATHS[chip]:
        performance_data["status"] = "error"
        performance_data["error"] = "无效芯片或模型组"
        return performance_data
    group_path = JSON_PATHS[chip][group]
    if not os.path.exists(group_path):
        performance_data["status"] = "error"
        performance_data["error"] = "JSON路径不存在"
        return performance_data
    os.makedirs(group_path, exist_ok=True)
    # 3. optionally refresh local files
    if refresh_data:
        logger.info(f"开始刷新下载,共有{len(build_id_map)}个文件需要处理")
        downloaded_files, download_errors = download_files_with_build_ids(build_id_map, chip, group)
        performance_data["downloaded_files"] = downloaded_files
        performance_data["download_errors"] = download_errors
        performance_data["refresh_status"] = "executed"
    else:
        performance_data["refresh_status"] = "skipped"
        logger.info("跳过文件下载,仅使用本地现有数据")
    # 4. parse all local JSON files and shape them for the front end
    json_files = glob.glob(os.path.join(group_path, "*.json"))
    performance_data["file_count"] = len(json_files)
    for json_file in json_files:
        model_data_parsed = parse_json_file(json_file)
        if model_data_parsed:
            model_full_name = model_data_parsed["model_full_name"]
            build_info = build_id_map.get(model_full_name, {})
            build_id = build_info.get("build_id", "NA")
            model_for_frontend = {
                "model_full_name": model_full_name,
                "prebuild_id": build_id,
                "latency": model_data_parsed["om_value"],
                "mean_ddr": model_data_parsed["mean_ddr"],
                "bandwidth": model_data_parsed["bandwidth"],
                "timestamp": model_data_parsed["last_modified"]
            }
            performance_data["models"].append(model_for_frontend)
    return performance_data


@app.route('/api/refresh', methods=['POST'])
def refresh_data_api():
    """Manually trigger a re-download of all files for a chip/group."""
    start_time = time.time()
    try:
        data = request.json or {}  # fix: request.json is None without a JSON body
        chip = data.get('chip', 'Ascend610Lite')
        group = data.get('group', 'rl_nn')
        logger.info(f"手动刷新请求 - 芯片: {chip}, 组: {group}")
        build_id_map, build_missing = extract_build_ids_from_track(chip, group)
        downloaded_files, download_errors = download_files_with_build_ids(build_id_map, chip, group)
        response = {
            "status": "success",
            "chip": chip,
            "group": group,
            "downloaded_files": downloaded_files,
            "download_errors": download_errors,
            "build_missing": build_missing,
            "process_time": round(time.time() - start_time, 4)
        }
        return jsonify(response)
    except Exception as e:
        logger.exception("刷新数据时出错")
        return jsonify({
            "status": "error",
            "error": "服务器内部错误",
            "details": str(e),
            "process_time": round(time.time() - start_time, 4)
        }), 500


@app.route('/api/performance', methods=['GET'])
def performance_api():
    """Performance-data endpoint; ?refresh=true forces a download pass."""
    start_time = time.time()
    try:
        device = request.args.get('device', 'Ascend610Lite')
        group = request.args.get('type', 'rl_nn')
        refresh = request.args.get('refresh', 'false').lower() == 'true'
        logger.info(f"性能API请求 - 设备: {device}, 组: {group}, 刷新: {refresh}")
        performance_data = get_performance_data(device, group, refresh_data=refresh)
        process_time = time.time() - start_time
        response = jsonify({**performance_data, "process_time": round(process_time, 4)})
        response.headers['Cache-Control'] = 'public, max-age=300'
        return response
    except Exception as e:
        logger.exception("处理请求时出错")
        return jsonify({
            "status": "error",
            "error": "服务器内部错误",
            "details": str(e),
            "process_time": round(time.time() - start_time, 4)
        }), 500


@app.route('/')
def home():
    """Serve the front-end page."""
    return render_template('index.html')


if __name__ == '__main__':
    # ensure the download tree exists
    os.makedirs(DOWNLOAD_PATH, exist_ok=True)
    for chip, groups in JSON_PATHS.items():
        for group, path in groups.items():
            os.makedirs(path, exist_ok=True)
    parser = argparse.ArgumentParser(description='AI芯片性能监控服务')
    parser.add_argument('--host', type=str, default='127.0.0.1', help='服务监听地址')
    parser.add_argument('--port', type=int, default=8080, help='服务监听端口')
    parser.add_argument('--print', action='store_true', help='启动时打印所有.om和mean_ddr值到终端')
    args = parser.parse_args()
    if args.print:
        print_all_om_and_mean_ddr()
    # fix: honor --host/--port (previously hard-coded to 127.0.0.1:8080)
    app.run(host=args.host, port=args.port, debug=True)

# (原帖求助文字,保留为注释: 更改此脚本来满足我的要求)
最新发布
08-19
import os
import json
import re
import glob
import time
import logging
import argparse
from datetime import datetime

import requests
from flask import Flask, jsonify, request, render_template

app = Flask(__name__)

# Logging configuration
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Base paths
BASE_PATH = r"C:\Users\l30078648\Desktop\250730"
DOWNLOAD_PATH = os.path.join(BASE_PATH, "downloads")

# S3 OBS API configuration
S3_OBS_API_URL = "https://adsci.api.ias.huawei.com/icollect/get_s3_obs_file_download_url/"
S3_OBS_HEADERS = {
    "Accept-Language": "zh-CN",
    "Content-Type": "application/json"
}
USER_ID = "h00512356"

# Supported chips and model groups; both path maps below are derived from these
# so the two stay in sync when a chip/group is added.
_CHIPS = ("Ascend610Lite", "BS9SX1A")
_GROUPS = ("rl_nn", "rsc_nn", "prediction_nn")

# <chip> -> <group> -> directory holding the .track source JSON files
TRACK_PATHS = {
    chip: {group: os.path.join(BASE_PATH, ".track", chip, group) for group in _GROUPS}
    for chip in _CHIPS
}

# <chip> -> <group> -> directory where downloaded profiling JSONs are stored
JSON_PATHS = {
    chip: {group: os.path.join(DOWNLOAD_PATH, chip, group) for group in _GROUPS}
    for chip in _CHIPS
}


def send_request(method, url, headers=None, data=None, timeout=30, max_try_times=1, **kwargs):
    """Send an HTTP request with a simple retry loop.

    :param method: HTTP method name ("GET", "POST", ...).
    :param url: target URL.
    :param headers: optional header dict.
    :param data: payload; dicts are JSON-encoded, strings are sent as UTF-8 bytes.
    :param timeout: per-attempt timeout in seconds.
    :param max_try_times: maximum number of attempts.
    :return: the last ``requests.Response`` (possibly a failed one), or ``None``
             if every attempt raised before a response object was produced.
    """
    try_times = 0
    response = None
    if data is None:
        data = {}
    if 'files' in kwargs:
        # Multipart upload: let requests encode the form fields itself.
        request_body = data
    else:
        request_body = json.dumps(data) if not isinstance(data, str) else data.encode("utf-8")
    while try_times < max_try_times:
        try:
            # NOTE(review): verify=False disables TLS certificate validation;
            # acceptable only for a trusted intranet endpoint — confirm.
            response = requests.request(
                method=method,
                url=url,
                headers=headers,
                data=request_body,
                timeout=timeout,
                verify=False,
                **kwargs
            )
            if response.status_code in (200, 201):
                break
        except requests.exceptions.Timeout:
            logger.warning(f"请求超时: url={url}")
        except Exception as e:
            logger.exception(f"请求失败: url={url}, 错误: {str(e)}")
            if response:
                logger.error(f"响应内容: {response.text}")
        try_times += 1
        logger.warning(f"请求失败, url={url}, 重试中 ({try_times}/{max_try_times})")
        # Back off between attempts, but not after the final one (the original
        # slept 5s even when the loop was about to exit).
        if max_try_times > 1 and try_times < max_try_times:
            time.sleep(5)
    return response


def download_file_from_s3_obs(prebuild_id, model_name, chip, group):
    """Download one profiling JSON from S3 OBS and return its parsed content.

    :return: tuple ``(save_path, file_data)`` — ``(None, None)`` when the
             download URL could not be obtained or the download failed;
             ``(save_path, None)`` when the file was saved but is not valid JSON.
    """
    # Build the S3 OBS path for this prebuild's profiling summary.
    s3_obs_url = f"hpct/nn_model_prebuild/{prebuild_id}_PREBUILD_OUTPUTS/profiling_summary/"
    file_name = f"{chip}_prof_record.json"
    payload = {
        "user_id": USER_ID,
        "s3_obs_url": s3_obs_url,
        "file_name": file_name
    }
    logger.info(f"请求S3 OBS下载URL: {payload}")

    # Step 1: ask the API for a signed download URL.
    response = send_request(
        "POST",
        S3_OBS_API_URL,
        headers=S3_OBS_HEADERS,
        data=payload
    )
    if not response or response.status_code != 200:
        logger.error(f"获取下载URL失败: 状态码={getattr(response, 'status_code', '无响应')}")
        return None, None
    try:
        result = response.json()
        download_url = result.get("download_url")
        if not download_url:
            logger.error(f"响应中缺少download_url字段: {result}")
            return None, None
    except Exception as e:
        logger.error(f"解析响应失败: {str(e)}")
        return None, None

    # Step 2: fetch the file itself.
    logger.info(f"下载文件: {download_url}")
    file_response = send_request("GET", download_url, timeout=60)
    if not file_response or file_response.status_code != 200:
        logger.error(f"文件下载失败: 状态码={getattr(file_response, 'status_code', '无响应')}")
        return None, None

    # Step 3: persist to the local JSON directory as "<model>_<chip>_prof_record.json".
    save_dir = JSON_PATHS[chip][group]
    os.makedirs(save_dir, exist_ok=True)
    save_file_name = f"{model_name}_{file_name}"
    save_path = os.path.join(save_dir, save_file_name)
    with open(save_path, 'wb') as f:
        f.write(file_response.content)
    logger.info(f"文件保存成功: {save_path}")

    # Step 4: parse immediately and echo the key metrics to the terminal.
    try:
        with open(save_path, 'r', encoding='utf-8') as f:
            file_data = json.load(f)
        om_value, mean_ddr = extract_om_and_mean_ddr(file_data)
        print(f"文件: {save_file_name}")
        print(f"  .om值: {om_value if om_value is not None else '未找到'}")
        print(f"  mean_ddr: {mean_ddr if mean_ddr is not None else '未找到'}")
        print("-" * 50)
        return save_path, file_data
    except Exception as e:
        logger.error(f"解析下载的JSON文件失败: {save_path}, 错误: {str(e)}")
        return save_path, None


def extract_om_and_mean_ddr(data):
    """Extract the value of the first ``*.om`` key and the ``mean_ddr`` value.

    :param data: parsed JSON dict.
    :return: tuple ``(om_value, mean_ddr_value)``; each element is ``None``
             when the corresponding key is absent.
    """
    om_value = None
    mean_ddr = None

    # First key ending in ".om" wins; ordering follows dict insertion order.
    om_keys = [key for key in data.keys() if key.endswith('.om')]
    if om_keys:
        om_value = data[om_keys[0]]

    # mean_ddr may live at the top level or one level down inside a sub-dict.
    if 'mean_ddr' in data:
        mean_ddr = data['mean_ddr']
    else:
        for value in data.values():
            if isinstance(value, dict) and 'mean_ddr' in value:
                mean_ddr = value['mean_ddr']
                break
    return om_value, mean_ddr


def extract_prebuild_ids_from_track(chip, group):
    """Collect ``prebuild_id`` values from the .track folder for chip/group.

    :return: tuple ``(prebuild_data, prebuild_missing)`` where ``prebuild_data``
             maps model name -> prebuild_id (or ``None`` when the field is
             missing) and ``prebuild_missing`` lists models/files without one.
    """
    prebuild_data = {}
    prebuild_missing = []

    if chip not in TRACK_PATHS or group not in TRACK_PATHS[chip]:
        logger.error(f"无效路径: {chip}/{group}")
        return prebuild_data, prebuild_missing

    group_path = TRACK_PATHS[chip][group]
    if not os.path.exists(group_path):
        logger.error(f"原始路径不存在: {group_path}")
        return prebuild_data, prebuild_missing

    json_files = glob.glob(os.path.join(group_path, "*.json"))
    logger.info(f"在路径 {group_path} 中找到 {len(json_files)} 个JSON文件")

    for json_file in json_files:
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Model name is the track file's stem (may itself contain '_').
            model_name = os.path.splitext(os.path.basename(json_file))[0]
            prebuild_id = data.get('prebuild_id')
            # Normalise numeric ids to strings.
            if isinstance(prebuild_id, int):
                prebuild_id = str(prebuild_id)
            if prebuild_id:
                prebuild_data[model_name] = prebuild_id
                logger.debug(f"提取成功: {model_name} -> {prebuild_id}")
            else:
                logger.warning(f"文件 {json_file} 中没有找到prebuild_id字段")
                prebuild_data[model_name] = None
                prebuild_missing.append(model_name)
        except Exception as e:
            logger.error(f"解析原始文件 {json_file} 时出错: {str(e)}")
            prebuild_missing.append(os.path.basename(json_file))
            continue

    # BUG FIX: count ids actually extracted; parse-error files appear in
    # prebuild_missing but never in prebuild_data, so the old subtraction
    # (len(prebuild_data) - len(prebuild_missing)) undercounted.
    extracted = sum(1 for v in prebuild_data.values() if v)
    logger.info(f"成功提取 {extracted} 个prebuild_id")
    return prebuild_data, prebuild_missing


def download_files_using_prebuild_ids(prebuild_id_map, chip, group):
    """Download the profiling file for every model with a known prebuild_id.

    :param prebuild_id_map: model name -> prebuild_id (``None`` entries are skipped).
    :return: tuple ``(downloaded_files, download_errors, downloaded_data)``;
             ``downloaded_data`` is currently always empty (reserved for callers).
    """
    downloaded_files = []
    download_errors = []
    downloaded_data = {}

    for model_name, prebuild_id in prebuild_id_map.items():
        if prebuild_id is None:
            logger.warning(f"跳过缺少prebuild_id的文件: {model_name}")
            downloaded_files.append({
                "model": model_name,
                "status": "skipped",
                "reason": "missing_prebuild_id"
            })
            continue

        logger.info(f"下载文件: {model_name} (prebuild_id={prebuild_id})")
        save_path, file_data = download_file_from_s3_obs(prebuild_id, model_name, chip, group)
        if save_path:
            if file_data:
                status = "success"
            else:
                # File was saved but could not be parsed as JSON.
                status = "partial_success"
                download_errors.append({
                    "model": model_name,
                    "prebuild_id": prebuild_id,
                    "error": "文件解析失败"
                })
            downloaded_files.append({
                "model": model_name,
                "status": status
            })
        else:
            logger.warning(f"下载失败: {model_name}")
            download_errors.append({
                "model": model_name,
                "prebuild_id": prebuild_id,
                "error": "下载失败"
            })
            downloaded_files.append({
                "model": model_name,
                "status": "failed"
            })
    return downloaded_files, download_errors, downloaded_data


# Downloaded files are named "<model>_<chip>_prof_record.json"; this strips the
# trailing "_<chip>_prof_record" so the model name survives even when it
# contains underscores.
_SAVED_FILE_STEM_RE = re.compile(r'^(?P<model>.+)_[^_]+_prof_record$')


def _model_name_from_file(file_name):
    """Recover the model name from a saved profiling file name."""
    stem = os.path.splitext(file_name)[0]
    match = _SAVED_FILE_STEM_RE.match(stem)
    if match:
        return match.group('model')
    # Fallback to the original heuristic for unexpected file names.
    return file_name.split('_')[0]


def parse_json_file(file_path, file_data=None):
    """Parse a downloaded profiling JSON and extract its performance metrics.

    :param file_path: path to the JSON file.
    :param file_data: optional pre-parsed content (skips re-reading the file).
    :return: metrics dict, or ``None`` when the file cannot be read/parsed.
    """
    if file_data is None:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                file_data = json.load(f)
        except Exception as e:
            logger.error(f"读取文件出错 {file_path}: {e}")
            return None

    file_name = os.path.basename(file_path)
    # BUG FIX: the old split('_')[0] truncated model names containing '_',
    # which then failed the prebuild_id_map lookup in get_performance_data.
    model_name = _model_name_from_file(file_name)

    om_value, mean_ddr = extract_om_and_mean_ddr(file_data)

    file_size = os.path.getsize(file_path)
    last_modified = datetime.fromtimestamp(os.path.getmtime(file_path)).isoformat()

    return {
        "model_name": model_name,
        "file_name": file_name,
        "file_path": file_path,
        "file_size": file_size,
        "last_modified": last_modified,
        "om_value": om_value,
        "mean_ddr": mean_ddr
    }


def print_all_om_and_mean_ddr():
    """Print every device/group's .om and mean_ddr values to the terminal."""
    print("\n" + "=" * 80)
    print("JSON性能数据提取报告 - .om值和mean_ddr")
    print("=" * 80)

    devices = [d for d in os.listdir(DOWNLOAD_PATH)
               if os.path.isdir(os.path.join(DOWNLOAD_PATH, d))]
    if not devices:
        print(f"在 {DOWNLOAD_PATH} 中未找到设备目录")
        return

    total_files = 0
    valid_files = 0

    for device in devices:
        device_path = os.path.join(DOWNLOAD_PATH, device)
        groups = [g for g in os.listdir(device_path)
                  if os.path.isdir(os.path.join(device_path, g))]
        if not groups:
            print(f"\n设备 {device} 中没有模型组")
            continue

        print(f"\n设备: {device} ({len(groups)}个模型组)")
        print("-" * 70)

        for group in groups:
            group_path = os.path.join(device_path, group)
            json_files = glob.glob(os.path.join(group_path, "*.json"))
            if not json_files:
                print(f"  ├── 组 {group}: 没有JSON文件")
                continue

            print(f"  ├── 组 {group}: {len(json_files)}个模型")
            for json_file in json_files:
                metrics = parse_json_file(json_file)
                filename = os.path.basename(json_file)
                total_files += 1
                # BUG FIX: parse_json_file returns None for unreadable files;
                # the old code dereferenced metrics['om_value'] unconditionally.
                if metrics is None:
                    print(f"  │   ├── {filename[:30]:<30} | 解析失败")
                    continue
                om_str = f"{metrics['om_value']:.4f}" if metrics['om_value'] is not None else "N/A"
                ddr_str = f"{metrics['mean_ddr']:.4f}" if metrics['mean_ddr'] is not None else "N/A"
                print(f"  │   ├── {filename[:30]:<30} | .om值: {om_str:<10} | mean_ddr: {ddr_str:<10}")
                if metrics['om_value'] is not None and metrics['mean_ddr'] is not None:
                    valid_files += 1

    print("\n" + "=" * 80)
    # BUG FIX: guard against ZeroDivisionError when no files were found.
    if total_files:
        print(f"扫描完成: 共处理 {total_files} 个文件, 有效数据 {valid_files} 个 ({valid_files/total_files*100:.1f}%)")
    else:
        print("扫描完成: 共处理 0 个文件")
    print("=" * 80 + "\n")


def get_performance_data(chip, group, refresh_data=False):
    """Assemble the performance payload; ``refresh_data`` controls downloading.

    :param chip: chip key in TRACK_PATHS/JSON_PATHS.
    :param group: model group key.
    :param refresh_data: when True, re-download every file before scanning.
    :return: dict with per-model metrics plus status/bookkeeping fields.
    """
    performance_data = {
        "status": "success",
        "models": [],
        "timestamp": datetime.now().isoformat(),
        "chip_type": chip,
        "group": group,
        "json_path": JSON_PATHS[chip].get(group, "") if chip in JSON_PATHS else "",
        "track_path": TRACK_PATHS[chip].get(group, "") if chip in TRACK_PATHS else "",
        "file_count": 0,
        "prebuild_id_count": 0,
        "refresh_performed": refresh_data,
        "download_errors": [],
        "downloaded_files": [],
        "prebuild_missing": []
    }

    # 1. Extract prebuild_ids from the .track folder.
    prebuild_id_map, prebuild_missing = extract_prebuild_ids_from_track(chip, group)
    # BUG FIX: same miscount as in extract_prebuild_ids_from_track — count
    # the ids actually present instead of subtracting the missing list.
    performance_data["prebuild_id_count"] = sum(1 for v in prebuild_id_map.values() if v)
    performance_data["prebuild_missing"] = prebuild_missing

    # 2. Validate the chip/group against the configured paths.
    if chip not in JSON_PATHS or group not in JSON_PATHS[chip]:
        performance_data["status"] = "error"
        performance_data["error"] = "无效芯片或模型组"
        return performance_data

    group_path = JSON_PATHS[chip][group]
    if not os.path.exists(group_path):
        performance_data["status"] = "error"
        performance_data["error"] = "JSON路径不存在"
        return performance_data

    # 3. Download all files when a refresh was requested.
    if refresh_data:
        logger.info(f"开始刷新下载,共有{len(prebuild_id_map)}个文件需要处理")
        downloaded_files, download_errors, _ = download_files_using_prebuild_ids(prebuild_id_map, chip, group)
        performance_data["downloaded_files"] = downloaded_files
        performance_data["download_errors"] = download_errors
        performance_data["refresh_status"] = "executed"
    else:
        performance_data["refresh_status"] = "skipped"
        logger.info("跳过文件下载,仅使用本地现有数据")

    # 4. Scan every local JSON file and extract its metrics.
    json_files = glob.glob(os.path.join(group_path, "*.json"))
    performance_data["file_count"] = len(json_files)
    for json_file in json_files:
        model_data = parse_json_file(json_file)
        if model_data:
            model_data["prebuild_id"] = prebuild_id_map.get(model_data["model_name"], "NA")
            performance_data["models"].append(model_data)
    return performance_data


@app.route('/api/refresh', methods=['POST'])
def refresh_data_api():
    """Manually trigger a refresh (extract prebuild_ids, download all files)."""
    start_time = time.time()
    try:
        # BUG FIX: request.json raises / yields None when the POST carries no
        # JSON body; fall back to an empty dict so the defaults apply.
        data = request.get_json(silent=True) or {}
        chip = data.get('chip', 'Ascend610Lite')
        group = data.get('group', 'rl_nn')
        logger.info(f"手动刷新请求 - 芯片: {chip}, 组: {group}")

        # 1. Extract prebuild_ids from the .track folder.
        prebuild_id_map, prebuild_missing = extract_prebuild_ids_from_track(chip, group)

        # 2. Download every file.
        downloaded_files, download_errors, _ = download_files_using_prebuild_ids(prebuild_id_map, chip, group)

        # 3. Build the response.
        response = {
            "status": "success",
            "chip": chip,
            "group": group,
            "downloaded_files": downloaded_files,
            "download_errors": download_errors,
            "prebuild_missing": prebuild_missing,
            "process_time": round(time.time() - start_time, 4)
        }
        return jsonify(response)
    except Exception as e:
        logger.exception("刷新数据时出错")
        return jsonify({
            "status": "error",
            "error": "服务器内部错误",
            "details": str(e),
            "process_time": round(time.time() - start_time, 4)
        }), 500


@app.route('/api/performance', methods=['GET'])
def performance_api():
    """Performance-data endpoint; ``refresh`` query param controls downloading."""
    start_time = time.time()
    try:
        device = request.args.get('device', 'Ascend610Lite')
        group = request.args.get('type', 'rl_nn')
        refresh = request.args.get('refresh', 'false').lower() == 'true'
        logger.info(f"性能API请求 - 设备: {device}, 组: {group}, 刷新: {refresh}")

        performance_data = get_performance_data(device, group, refresh_data=refresh)
        process_time = time.time() - start_time
        response = jsonify({
            **performance_data,
            "process_time": round(process_time, 4)
        })
        # Allow clients/proxies to cache the payload for five minutes.
        response.headers['Cache-Control'] = 'public, max-age=300'
        return response
    except Exception as e:
        logger.exception("处理请求时出错")
        return jsonify({
            "status": "error",
            "error": "服务器内部错误",
            "details": str(e),
            "process_time": round(time.time() - start_time, 4)
        }), 500


@app.route('/')
def home():
    """Home page route."""
    return render_template('index.html')


if __name__ == '__main__':
    # Ensure the download directory and every JSON sub-directory exist.
    os.makedirs(DOWNLOAD_PATH, exist_ok=True)
    for chip, groups in JSON_PATHS.items():
        for group, path in groups.items():
            os.makedirs(path, exist_ok=True)

    parser = argparse.ArgumentParser(description='AI芯片性能监控服务')
    parser.add_argument('--host', type=str, default='127.0.0.1', help='服务监听地址')
    parser.add_argument('--port', type=int, default=8080, help='服务监听端口')
    parser.add_argument('--print', action='store_true', help='启动时打印所有.om和mean_ddr值到终端')
    args = parser.parse_args()

    if args.print:
        print_all_om_and_mean_ddr()

    # BUG FIX: honour the parsed CLI arguments — the original hard-coded
    # host/port, silently ignoring --host and --port.
    app.run(host=args.host, port=args.port, debug=True)
修改此脚本,如果下载失败就尝试提取build_id下载
08-12
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值