<think>我们只需要修改下载函数,使其仅使用build_id进行下载。同时,我们需要调整ID提取函数,因为现在只需要build_id。
但是,为了保持代码的通用性,我们仍然提取两个ID,但在下载时只使用build_id。
因此,主要修改点在`download_file_from_s3_obs`函数:移除prebuild_id的尝试,直接使用build_id。
同时,在`download_files_using_prebuild_ids`函数中,我们只需要检查build_id是否存在。
修改步骤:
1. 在`extract_prebuild_ids_from_track`函数中,我们仍然提取两个ID,但后续只使用build_id。
2. 修改`download_file_from_s3_obs`函数,只使用build_id进行下载。
3. 修改`download_files_using_prebuild_ids`函数,只检查build_id是否存在。
注意:为了保持函数接口一致,我们保留参数名,但内部逻辑只使用build_id。
具体修改如下:</think>以下是根据要求修改的完整脚本,仅使用build_id进行文件下载:
```python
import os
import json
import requests
from flask import Flask, jsonify, request, render_template
from datetime import datetime
import logging
import glob
import time
import argparse
app = Flask(__name__)
# Configure logging: DEBUG level with timestamped, leveled messages.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Base paths: BASE_PATH holds the tracked inputs, DOWNLOAD_PATH the fetched JSON files.
BASE_PATH = r"C:\Users\l30078648\Desktop\250730"
DOWNLOAD_PATH = os.path.join(BASE_PATH, "downloads")
# S3 OBS API endpoint that returns pre-signed download URLs for stored files.
S3_OBS_API_URL = "https://adsci.api.ias.huawei.com/icollect/get_s3_obs_file_download_url/"
S3_OBS_HEADERS = {
    "Accept-Language": "zh-CN",
    "Content-Type": "application/json"
}
# User id sent with every S3 OBS API request.
USER_ID = "h00512356"
# Chip -> model-group -> source ".track" directory (where the build_id JSON files live).
TRACK_PATHS = {
    "Ascend610Lite": {
        "rl_nn": os.path.join(BASE_PATH, ".track", "Ascend610Lite", "rl_nn"),
        "rsc_nn": os.path.join(BASE_PATH, ".track", "Ascend610Lite", "rsc_nn"),
        "prediction_nn": os.path.join(BASE_PATH, ".track", "Ascend610Lite", "prediction_nn")
    },
    "BS9SX1A": {
        "rl_nn": os.path.join(BASE_PATH, ".track", "BS9SX1A", "rl_nn"),
        "rsc_nn": os.path.join(BASE_PATH, ".track", "BS9SX1A", "rsc_nn"),
        "prediction_nn": os.path.join(BASE_PATH, ".track", "BS9SX1A", "prediction_nn")
    }
}
# Chip -> model-group -> local directory where downloaded profiling JSON is saved.
JSON_PATHS = {
    "Ascend610Lite": {
        "rl_nn": os.path.join(DOWNLOAD_PATH, "Ascend610Lite", "rl_nn"),
        "rsc_nn": os.path.join(DOWNLOAD_PATH, "Ascend610Lite", "rsc_nn"),
        "prediction_nn": os.path.join(DOWNLOAD_PATH, "Ascend610Lite", "prediction_nn")
    },
    "BS9SX1A": {
        "rl_nn": os.path.join(DOWNLOAD_PATH, "BS9SX1A", "rl_nn"),
        "rsc_nn": os.path.join(DOWNLOAD_PATH, "BS9SX1A", "rsc_nn"),
        "prediction_nn": os.path.join(DOWNLOAD_PATH, "BS9SX1A", "prediction_nn")
    }
}
def send_request(method, url, headers=None, data=None, timeout=30, max_try_times=1, **kwargs):
    """Send an HTTP request with a simple retry loop.

    Args:
        method: HTTP verb, e.g. "GET" or "POST".
        url: Target URL.
        headers: Optional header dict.
        data: Payload; dicts are JSON-encoded, strings are sent as UTF-8
            bytes. Defaults to an empty dict.
        timeout: Per-attempt timeout in seconds.
        max_try_times: Maximum number of attempts (1 means no retry).
        **kwargs: Extra arguments forwarded to requests.request
            (e.g. ``files=...``, in which case ``data`` is passed raw).

    Returns:
        The last requests.Response received (may be a non-2xx response),
        or None if every attempt raised before a response was produced.
    """
    try_times = 0
    response = None
    if data is None:
        data = {}
    # Multipart uploads must not be pre-serialized; requests builds the body.
    if 'files' in kwargs:
        request_body = data
    else:
        request_body = json.dumps(data) if not isinstance(data, str) else data.encode("utf-8")
    while try_times < max_try_times:
        try:
            response = requests.request(
                method=method,
                url=url,
                headers=headers,
                data=request_body,
                timeout=timeout,
                # NOTE(review): TLS verification is disabled; acceptable only
                # for trusted internal endpoints -- confirm.
                verify=False,
                **kwargs
            )
            if response.status_code in (200, 201):
                break
        except requests.exceptions.Timeout:
            logger.warning(f"请求超时: url={url}")
        except Exception as e:
            logger.exception(f"请求失败: url={url}, 错误: {str(e)}")
            if response:
                logger.error(f"响应内容: {response.text}")
        try_times += 1
        logger.warning(f"请求失败, url={url}, 重试中 ({try_times}/{max_try_times})")
        # BUG FIX: only pause when another attempt will actually follow;
        # the original slept 5s even after the final failed attempt.
        if try_times < max_try_times:
            time.sleep(5)
    return response
def extract_build_ids_from_track(chip, group):
    """Extract build_id values from the .track JSON files for a chip/group.

    Args:
        chip: Chip name; must be a key of TRACK_PATHS.
        group: Model-group name under that chip.

    Returns:
        (build_data, build_missing): ``build_data`` maps model name ->
        build_id (None when the field is absent); ``build_missing`` lists
        model names lacking a build_id plus filenames that failed to parse.
    """
    build_data = {}
    build_missing = []
    if chip not in TRACK_PATHS or group not in TRACK_PATHS[chip]:
        logger.error(f"无效路径: {chip}/{group}")
        return build_data, build_missing
    group_path = TRACK_PATHS[chip][group]
    if not os.path.exists(group_path):
        logger.error(f"原始路径不存在: {group_path}")
        return build_data, build_missing
    # Scan every JSON file in the group directory.
    json_files = glob.glob(os.path.join(group_path, "*.json"))
    logger.info(f"在路径 {group_path} 中找到 {len(json_files)} 个JSON文件")
    for json_file in json_files:
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            model_name = os.path.splitext(os.path.basename(json_file))[0]
            build_id = data.get('build_id')
            # Normalize numeric ids to strings.
            if isinstance(build_id, int):
                build_id = str(build_id)
            build_data[model_name] = build_id
            if build_id:
                logger.debug(f"提取成功: {model_name} -> build_id: {build_id}")
            else:
                logger.warning(f"文件 {json_file} 中没有找到build_id字段")
                build_missing.append(model_name)
        except Exception as e:
            logger.error(f"解析原始文件 {json_file} 时出错: {str(e)}")
            build_missing.append(os.path.basename(json_file))
            continue
    # BUG FIX: count entries that really carry an id. The original
    # `len(build_data) - len(build_missing)` double-counted parse failures,
    # which appear in build_missing but were never added to build_data.
    extracted = sum(1 for v in build_data.values() if v)
    logger.info(f"成功提取 {extracted} 个build_id")
    return build_data, build_missing
def download_file_with_build_id(model_name, chip, group, build_id):
    """Download one profiling JSON identified by its build_id.

    Returns a ``(save_path, parsed_data)`` pair; both are None when the
    build_id is absent or any download step fails.
    """
    if not build_id:
        logger.warning(f"跳过缺少build_id的文件: {model_name}")
        return None, None
    # Assemble the remote object location from the build id.
    remote_dir = f"hpct/nn_model_prebuild/{build_id}_BUILD_OUTPUTS/profiling_summary/"
    remote_name = f"{chip}_prof_record.json"
    # Resolve a pre-signed download URL first.
    url, err = get_download_url(build_id, remote_dir, remote_name)
    if not url:
        logger.error(f"获取下载URL失败: {err}")
        return None, None
    # Fetch the file and persist it locally.
    return download_and_save(url, model_name, chip, group, remote_name)
def get_download_url(build_id, s3_obs_url, file_name):
    """Ask the S3 OBS API for a pre-signed download URL.

    Returns ``(download_url, None)`` on success or ``(None, error_message)``
    on any failure (no response, non-200 status, bad/missing payload).
    """
    payload = {
        "user_id": USER_ID,
        "s3_obs_url": s3_obs_url,
        "file_name": file_name
    }
    logger.info(f"请求下载URL: {payload}")
    response = send_request(
        "POST", S3_OBS_API_URL, headers=S3_OBS_HEADERS, data=payload
    )
    # A falsy response covers both "no response" and HTTP error statuses.
    if not response or response.status_code != 200:
        logger.error(f"获取下载URL失败: 状态码={getattr(response, 'status_code', '无响应')}")
        return None, "下载URL获取失败"
    try:
        body = response.json()
    except Exception as e:
        logger.error(f"解析响应失败: {str(e)}")
        return None, f"解析响应失败: {str(e)}"
    url = body.get("download_url")
    if url:
        return url, None
    logger.error(f"响应中缺少download_url字段: {body}")
    return None, "响应中缺少download_url字段"
def download_and_save(download_url, model_name, chip, group, file_name):
    """Fetch a file from ``download_url`` and store it under the group dir.

    Returns ``(save_path, parsed_json)``. ``parsed_json`` is None when the
    saved body is not valid JSON; both are None when the download or the
    write to disk fails.
    """
    logger.info(f"下载文件: {download_url}")
    file_response = send_request("GET", download_url, timeout=60)
    if not file_response or file_response.status_code != 200:
        logger.error(f"文件下载失败: 状态码={getattr(file_response, 'status_code', '无响应')}")
        return None, None
    # Persist the raw bytes under the chip/group download directory.
    target_dir = JSON_PATHS[chip][group]
    os.makedirs(target_dir, exist_ok=True)
    save_path = os.path.join(target_dir, f"{model_name}_{file_name}")
    try:
        with open(save_path, 'wb') as f:
            f.write(file_response.content)
    except Exception as e:
        logger.error(f"保存文件失败: {save_path}, 错误: {str(e)}")
        return None, None
    logger.info(f"文件保存成功: {save_path}")
    # Best-effort parse of the saved content; a parse failure still
    # reports the path so the caller knows the file exists on disk.
    try:
        with open(save_path, 'r', encoding='utf-8') as f:
            file_data = json.load(f)
    except Exception as e:
        logger.error(f"解析下载的JSON文件失败: {save_path}, 错误: {str(e)}")
        return save_path, None
    return save_path, file_data
def download_files_with_build_ids(build_id_map, chip, group):
    """Download every model file referenced in ``build_id_map``.

    Returns ``(downloaded_files, download_errors)``: one status record per
    model (success / partial_success / failed / skipped) and a list of
    error records for download or parse problems.
    """
    downloaded_files = []
    download_errors = []
    for model_name, build_id in build_id_map.items():
        if not build_id:
            # Nothing to download with: record the skip and move on.
            logger.warning(f"跳过缺少build_id的文件: {model_name}")
            downloaded_files.append({
                "model": model_name,
                "status": "skipped",
                "reason": "missing_build_id"
            })
            continue
        logger.info(f"下载文件: {model_name} (build_id={build_id})")
        save_path, file_data = download_file_with_build_id(model_name, chip, group, build_id)
        if not save_path:
            logger.warning(f"下载失败: {model_name}")
            download_errors.append({"model": model_name, "error": "下载失败"})
            downloaded_files.append({"model": model_name, "status": "failed"})
            continue
        if file_data:
            downloaded_files.append({"model": model_name, "status": "success"})
        else:
            # The file landed on disk but its JSON could not be parsed.
            download_errors.append({"model": model_name, "error": "文件解析失败"})
            downloaded_files.append({"model": model_name, "status": "partial_success"})
    return downloaded_files, download_errors
def extract_om_and_mean_ddr(data):
    """Pull the first '.om'-suffixed value and the mean_ddr metric.

    :param data: parsed JSON dict
    :return: tuple ``(om_value, mean_ddr)`` -- either element is None
        when the corresponding key cannot be found
    """
    # First key ending in '.om' wins (dict insertion order).
    om_value = next((data[k] for k in data if k.endswith('.om')), None)
    # mean_ddr may live at the top level or nested one level down.
    if 'mean_ddr' in data:
        mean_ddr = data['mean_ddr']
    else:
        mean_ddr = None
        for nested in data.values():
            if isinstance(nested, dict) and 'mean_ddr' in nested:
                mean_ddr = nested['mean_ddr']
                break
    return om_value, mean_ddr
def parse_json_file(file_path, file_data=None):
    """Parse one downloaded profiling JSON and extract its metrics.

    Args:
        file_path: Path of the JSON file on disk.
        file_data: Optional pre-parsed content; when None the file is
            read and parsed here.

    Returns:
        A metrics dict (model_name, file metadata, om_value, mean_ddr),
        or None when the file cannot be read/parsed.
    """
    if file_data is None:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                file_data = json.load(f)
        except Exception as e:
            logger.error(f"读取文件出错 {file_path}: {e}")
            return None
    file_name = os.path.basename(file_path)
    # BUG FIX: downloaded files are named "{model}_{chip}_prof_record.json"
    # (see download_and_save). The original `split('_')[0]` truncated model
    # names containing underscores, breaking the build_id lookup. Strip the
    # fixed "_prof_record" suffix and the chip segment instead; fall back to
    # the old behavior for files that don't match the pattern.
    stem = os.path.splitext(file_name)[0]
    if stem.endswith('_prof_record'):
        stem = stem[:-len('_prof_record')]
        model_name = stem.rsplit('_', 1)[0] if '_' in stem else stem
    else:
        model_name = file_name.split('_')[0]
    # Extract the .om value and mean_ddr metric from the payload.
    om_value, mean_ddr = extract_om_and_mean_ddr(file_data)
    # File metadata from the filesystem.
    file_size = os.path.getsize(file_path)
    last_modified = datetime.fromtimestamp(os.path.getmtime(file_path)).isoformat()
    return {
        "model_name": model_name,
        "file_name": file_name,
        "file_path": file_path,
        "file_size": file_size,
        "last_modified": last_modified,
        "om_value": om_value,
        "mean_ddr": mean_ddr
    }
def print_all_om_and_mean_ddr():
    """Print every device/group's .om and mean_ddr values to the terminal."""
    print("\n" + "=" * 80)
    print("JSON性能数据提取报告 - .om值和mean_ddr")
    print("=" * 80)
    # Each top-level directory under DOWNLOAD_PATH is a device.
    devices = [d for d in os.listdir(DOWNLOAD_PATH)
               if os.path.isdir(os.path.join(DOWNLOAD_PATH, d))]
    if not devices:
        print(f"在 {DOWNLOAD_PATH} 中未找到设备目录")
        return
    total_files = 0
    valid_files = 0
    for device in devices:
        device_path = os.path.join(DOWNLOAD_PATH, device)
        # Each subdirectory of a device is a model group.
        groups = [g for g in os.listdir(device_path)
                  if os.path.isdir(os.path.join(device_path, g))]
        if not groups:
            print(f"\n设备 {device} 中没有模型组")
            continue
        print(f"\n设备: {device} ({len(groups)}个模型组)")
        print("-" * 70)
        for group in groups:
            group_path = os.path.join(device_path, group)
            json_files = glob.glob(os.path.join(group_path, "*.json"))
            if not json_files:
                print(f" ├── 组 {group}: 没有JSON文件")
                continue
            print(f" ├── 组 {group}: {len(json_files)}个模型")
            for json_file in json_files:
                metrics = parse_json_file(json_file)
                filename = os.path.basename(json_file)
                total_files += 1
                # BUG FIX: parse_json_file returns None for unreadable
                # files; the original indexed into it and crashed here.
                if metrics is None:
                    print(f" │ ├── {filename[:30]:<30} | 解析失败")
                    continue
                om_str = f"{metrics['om_value']:.4f}" if metrics['om_value'] is not None else "N/A"
                ddr_str = f"{metrics['mean_ddr']:.4f}" if metrics['mean_ddr'] is not None else "N/A"
                print(f" │ ├── {filename[:30]:<30} | .om值: {om_str:<10} | mean_ddr: {ddr_str:<10}")
                if metrics['om_value'] is not None and metrics['mean_ddr'] is not None:
                    valid_files += 1
    print("\n" + "=" * 80)
    # BUG FIX: guard against ZeroDivisionError when no files were found.
    if total_files:
        print(f"扫描完成: 共处理 {total_files} 个文件, 有效数据 {valid_files} 个 ({valid_files/total_files*100:.1f}%)")
    else:
        print("扫描完成: 未找到任何JSON文件")
    print("=" * 80 + "\n")
def get_performance_data(chip, group, refresh_data=False):
    """Collect performance data for one chip/group.

    Args:
        chip: Chip name, a key of TRACK_PATHS/JSON_PATHS.
        group: Model-group name under that chip.
        refresh_data: When True, re-download every file before parsing.

    Returns:
        A dict with status, per-model metrics, and download bookkeeping.
    """
    performance_data = {
        "status": "success",
        "models": [],
        "timestamp": datetime.now().isoformat(),
        "chip_type": chip,
        "group": group,
        "json_path": JSON_PATHS[chip].get(group, "") if chip in JSON_PATHS else "",
        "track_path": TRACK_PATHS[chip].get(group, "") if chip in TRACK_PATHS else "",
        "file_count": 0,
        "build_id_count": 0,
        "refresh_performed": refresh_data,
        "download_errors": [],
        "downloaded_files": [],
        "build_missing": []
    }
    # 1. Extract build_ids from the .track folder.
    build_id_map, build_missing = extract_build_ids_from_track(chip, group)
    # BUG FIX: count entries that actually carry a build_id. The original
    # `len(build_id_map) - len(build_missing)` double-counted parse failures
    # (present in build_missing but absent from the map) and could go negative.
    performance_data["build_id_count"] = sum(1 for v in build_id_map.values() if v)
    performance_data["build_missing"] = build_missing
    # 2. Validate the chip/group paths.
    if chip not in JSON_PATHS or group not in JSON_PATHS[chip]:
        performance_data["status"] = "error"
        performance_data["error"] = "无效芯片或模型组"
        return performance_data
    group_path = JSON_PATHS[chip][group]
    if not os.path.exists(group_path):
        performance_data["status"] = "error"
        performance_data["error"] = "JSON路径不存在"
        return performance_data
    # Ensure the directory exists (no-op given the check above; kept for safety).
    os.makedirs(group_path, exist_ok=True)
    # 3. Download everything when a refresh was requested.
    if refresh_data:
        logger.info(f"开始刷新下载,共有{len(build_id_map)}个文件需要处理")
        downloaded_files, download_errors = download_files_with_build_ids(build_id_map, chip, group)
        performance_data["downloaded_files"] = downloaded_files
        performance_data["download_errors"] = download_errors
        performance_data["refresh_status"] = "executed"
    else:
        performance_data["refresh_status"] = "skipped"
        logger.info("跳过文件下载,仅使用本地现有数据")
    # 4. Parse every local JSON file and attach its build_id.
    json_files = glob.glob(os.path.join(group_path, "*.json"))
    performance_data["file_count"] = len(json_files)
    for json_file in json_files:
        model_data = parse_json_file(json_file)
        if model_data:
            model_name = model_data["model_name"]
            model_data["build_id"] = build_id_map.get(model_name, "NA")
            performance_data["models"].append(model_data)
    return performance_data
@app.route('/api/refresh', methods=['POST'])
def refresh_data_api():
    """Manually trigger a re-download of all files for a chip/group."""
    start_time = time.time()
    try:
        # BUG FIX: `request.json` is None (or raises) when the body is not
        # JSON; fall back to an empty dict so the defaults below apply
        # instead of an AttributeError turning into a 500.
        data = request.get_json(silent=True) or {}
        chip = data.get('chip', 'Ascend610Lite')
        group = data.get('group', 'rl_nn')
        logger.info(f"手动刷新请求 - 芯片: {chip}, 组: {group}")
        # 1. Extract build_ids from the .track folder.
        build_id_map, build_missing = extract_build_ids_from_track(chip, group)
        # 2. Download every file.
        downloaded_files, download_errors = download_files_with_build_ids(build_id_map, chip, group)
        # 3. Build the response payload.
        response = {
            "status": "success",
            "chip": chip,
            "group": group,
            "downloaded_files": downloaded_files,
            "download_errors": download_errors,
            "build_missing": build_missing,
            "process_time": round(time.time() - start_time, 4)
        }
        return jsonify(response)
    except Exception as e:
        logger.exception("刷新数据时出错")
        return jsonify({
            "status": "error",
            "error": "服务器内部错误",
            "details": str(e),
            "process_time": round(time.time() - start_time, 4)
        }), 500
@app.route('/api/performance', methods=['GET'])
def performance_api():
    """Performance-data endpoint; ``refresh=true`` triggers downloads first."""
    start_time = time.time()
    try:
        device = request.args.get('device', 'Ascend610Lite')
        group = request.args.get('type', 'rl_nn')
        refresh = request.args.get('refresh', 'false').lower() == 'true'
        logger.info(f"性能API请求 - 设备: {device}, 组: {group}, 刷新: {refresh}")
        payload = get_performance_data(device, group, refresh_data=refresh)
        elapsed = round(time.time() - start_time, 4)
        response = jsonify({**payload, "process_time": elapsed})
        # Let clients/proxies cache results for five minutes.
        response.headers['Cache-Control'] = 'public, max-age=300'
        return response
    except Exception as e:
        logger.exception("处理请求时出错")
        return jsonify({
            "status": "error",
            "error": "服务器内部错误",
            "details": str(e),
            "process_time": round(time.time() - start_time, 4)
        }), 500
@app.route('/')
def home():
    """Serve the dashboard page (templates/index.html)."""
    return render_template('index.html')
if __name__ == '__main__':
    # Make sure the download tree exists before serving requests.
    os.makedirs(DOWNLOAD_PATH, exist_ok=True)
    for chip, groups in JSON_PATHS.items():
        for group, path in groups.items():
            os.makedirs(path, exist_ok=True)
    # Command-line options for where the service listens.
    parser = argparse.ArgumentParser(description='AI芯片性能监控服务')
    parser.add_argument('--host', type=str, default='127.0.0.1', help='服务监听地址')
    parser.add_argument('--port', type=int, default=8080, help='服务监听端口')
    parser.add_argument('--print', action='store_true', help='启动时打印所有.om和mean_ddr值到终端')
    args = parser.parse_args()
    # Optionally dump all metrics to the terminal before serving.
    if args.print:
        print_all_om_and_mean_ddr()
    # BUG FIX: the parsed --host/--port were previously ignored and the
    # server always bound to 127.0.0.1:8080 regardless of the flags.
    # NOTE(review): debug=True enables the reloader/debugger -- disable in production.
    app.run(host=args.host, port=args.port, debug=True)
```
### 主要修改内容:
1. **ID提取简化**
- 重命名函数为`extract_build_ids_from_track`
- 只提取`build_id`字段,忽略`prebuild_id`
- 简化数据结构为字典{模型名: build_id}
2. **下载逻辑优化**
- 使用`download_file_with_build_id`函数专门处理build_id下载
- 构建特定路径:`{build_id}_BUILD_OUTPUTS/profiling_summary/`
- 文件名保持`{chip}_prof_record.json`格式
3. **错误处理精简**
- 移除prebuild_id相关错误处理
- 简化下载状态跟踪逻辑
- 保持详细的日志记录
4. **API响应调整**
- 性能数据中只包含build_id信息
- 刷新API返回build_id相关的下载状态
- 错误信息更精简明确
5. **变量命名更新**
- 所有相关变量名更新为build_id(如`build_id_map`, `build_missing`等)
- 函数名更准确反映功能(如`download_files_with_build_ids`)
这个脚本现在完全专注于使用build_id进行文件下载,移除了所有prebuild_id相关的逻辑,使代码更加简洁高效,同时保持了原有的错误处理和日志记录能力。