复制下面的内容以代码形式输出给我
import time
import uuid
import json
import os
import traceback
import sqlite3
from concurrent.futures import ThreadPoolExecutor
from UTGen_server.model.api_task import ApiTask, BatchApiTask, UTTask, TaskStatus
from UTGen_server.queue.gencase_queue import UTGENTASKQUEUE
from UTGen_server.conf.setting import TESTCASE_FILE_NAME, setting
from UTGen_server.model.llvm_task import LLVMTask
from UTGen_server.model.prompt import SinglePrompt, AllPrompts
from UTGen_server.model.branch_design import BranchCoverDesign
from UTGen_server.utils.utils import seg_prompt, testcase_post_processing, read_file, read_json
from UTGen_server.workers.gen_branch_design import gen_branch_design
from UTGen_server.workers.gen_prompt import gen_prompt
from UTGen_server.workers.gen_case_by_llvm import llm_request_router, LlmRequest, qwen_chat_request
from UTGen_server.workers.compile_case import CompileProcessor
from UTGen_server.log.logs import logger, prompt_logger, case_logger, logger_path
from UTGen_server.track.code_tracker import track_workflow
# Reload the shared configuration object at import time. This runs before the
# GenTestCase class body below is executed, and that class captures
# setting.model / setting.mode as default argument values at definition time,
# so those defaults see whatever state reload() leaves behind.
# NOTE(review): exact reload semantics live in UTGen_server.conf.setting — confirm there.
setting.reload()
@track_workflow()
def get_testcase(request: LlmRequest, i: int, **kwargs):
    """Send one prompt to the routed LLM backend and return the generated text.

    The backend is chosen by the key ``"<request.model>-<request.mode>"`` in
    ``llm_request_router``; unknown keys fall back to ``qwen_chat_request``.

    Args:
        request: The LLM request (model, mode, prompt, ...).
        i: Zero-based index of the case, used only in log messages.
        **kwargs: Not used here; accepted so callers can pass extra context
            (e.g. ``session_id``) — presumably consumed by ``track_workflow``,
            TODO confirm.

    Returns:
        The completion text (``choices[0].text`` in ``completion`` mode,
        ``choices[0].message.content`` otherwise), or ``''`` when the response
        is not valid JSON or lacks the expected fields. Errors raised by the
        request call itself are propagated.
    """
    begin = time.time()
    router_ = f'{request.model}-{request.mode}'
    raw_response = llm_request_router.get(router_, qwen_chat_request)(request)

    # Parse inside a guard: a non-JSON body must degrade to '' like any other
    # malformed response, otherwise a single bad reply aborts a whole
    # thread-pool batch via future.result().
    try:
        test_string = json.loads(raw_response)
    except (TypeError, ValueError) as e:
        case_logger.info(f'第{i + 1}个用例的模型返回结果:\n{raw_response}')
        logger.error(f"errormsg: {str(e)}\n{traceback.format_exc()}")
        return ''
    case_logger.info(f'第{i + 1}个用例的模型返回结果:\n{test_string}')

    try:
        if request.mode == 'completion':
            answer = test_string['choices'][0]['text']
        else:
            answer = test_string['choices'][0]['message']['content']
    except (KeyError, IndexError, TypeError) as e:
        # JSON was valid but not shaped like a choices/completion payload.
        logger.error(f"errormsg: {str(e)}\n{traceback.format_exc()}")
        return ''

    times = time.time() - begin
    logger.info(f'已生成第{i + 1}个用例,用时:{times: .3f}s')
    return answer
def get_api_list_by_file(db_dir, apilist):
try:
symboldb = os.path.join(db_dir, "symbol.db")
utgendb = os.path.join(db_dir, "utgen.db")
apilite_new = []
def get_funclist_by_db(dbpath):
sc = sqlite3.connect(dbpath)
sc_cursor = sc.cursor()
sc_cursor.execute(sql)
results = sc_cursor.fetchall()
sc.close()
if results:
for result in results:
apilite_new.append(result[0])
for api in apilist:
if api.endswith(".h"):
sql = f'SELECT name from symbols where kind=12 AND DeclarationFileURI like "%{os.path.basename(api)}"'
get_funclist_by_db(symboldb)
elif api.endswith(".c"):
sql = f'SELECT name from func_jsons where file like "%{os.path.basename(api)}"'
get_funclist_by_db(utgendb)
else:
apilite_new.append(api)
except Exception as e:
logger.error(e)
return apilist
return apilite_new
class GenTestCase:
    """Entry points for scheduling and running unit-test case generation.

    The static methods fall into three groups:

    * ``background_mode`` / ``background_mode_return_id`` /
      ``batch_background_mode`` enqueue ``ApiTask`` records on
      ``UTGENTASKQUEUE`` and return immediately;
    * ``ut_gen_prompt`` / ``ut_gencase_for_single_branch`` /
      ``ut_gencase_for_all_branches`` generate prompts or cases synchronously;
    * ``start_gen_llvm_task`` / ``start_gen_gencase_task`` each process at most
      one pending queue entry per call (presumably invoked repeatedly by a
      polling loop elsewhere — TODO confirm).
    """

    # Class-level progress flags shared by the two start_gen_* steps. They only
    # decide when the "tasks finished" messages are logged.
    is_generating_llvm_task = False
    is_generating_gencase_task = False

    @staticmethod
    def _new_task_id():
        """Return a fresh 32-character lowercase hex id (a UUID4 without dashes)."""
        return uuid.uuid4().hex

    @staticmethod
    def _build_branch_design(task):
        """Run branch analysis for *task* and wrap its outputs in a BranchCoverDesign."""
        analysis_output = gen_branch_design(task)
        return BranchCoverDesign(
            output_path=analysis_output[0],
            branch_analysis=analysis_output[1],
            code_analysis=analysis_output[2],
            includes=analysis_output[3]
        )

    @staticmethod
    def background_mode(apitask: ApiTask):
        """Enqueue *apitask* for background generation.

        If ``constraint.db_dir_path`` is empty it defaults to the directory of
        ``constraint.ccjson`` (presumably a compile_commands.json path).

        Returns:
            ``"success"`` if the task was enqueued, ``"failed"`` otherwise.
        """
        ret = "success"
        if not apitask.constraint.db_dir_path:
            apitask.constraint.db_dir_path = os.path.dirname(apitask.constraint.ccjson)
        taskid = GenTestCase._new_task_id()
        try:
            UTGENTASKQUEUE.add_task(apitask.model_dump_json(), taskid)
        except Exception as error:
            ret = "failed"
            UTGENTASKQUEUE.update_api_tasks_err_message(taskid, error)
            logger.error(f"errormsg: {str(error)}\n{traceback.format_exc()}")
        return ret

    @staticmethod
    def background_mode_return_id(apitask: ApiTask):
        """Enqueue *apitask* and return the generated task id.

        ``apitask.testcase_file_path`` is treated as an output directory and is
        rewritten to the full file path built from ``TESTCASE_FILE_NAME`` and
        ``apitask.api_name``.

        Returns:
            The 32-character task id. NOTE(review): it is returned even when
            enqueueing failed; the failure is only recorded/logged.
        """
        if not apitask.constraint.db_dir_path:
            apitask.constraint.db_dir_path = os.path.dirname(apitask.constraint.ccjson)
        taskid = GenTestCase._new_task_id()
        try:
            # NOTE(review): unlike batch_background_mode, the api name is used
            # verbatim here (no suffix stripping, "::" replacement or lowercasing).
            apitask.testcase_file_path = os.path.join(
                apitask.testcase_file_path,
                TESTCASE_FILE_NAME.format(apitask.api_name)
            )
            UTGENTASKQUEUE.add_task(apitask.model_dump_json(), taskid)
        except Exception as error:
            UTGENTASKQUEUE.update_api_tasks_err_message(taskid, error)
            logger.error(f"errormsg: {str(error)}\n{traceback.format_exc()}")
        return taskid

    @staticmethod
    def batch_background_mode(batch_api_task: BatchApiTask):
        """Enqueue one ApiTask per API of *batch_api_task*, all under one task id.

        ``.h`` / ``.c`` entries of ``api_name_list`` are first expanded into
        function names via ``get_api_list_by_file``.

        Returns:
            A JSON string ``{"msg": "success" | "failed", "log_path": ...}``.
            On failure, tasks enqueued before the error stay enqueued.
        """
        ret = {"msg": "success", "log_path": logger_path}
        # A single id shared by every task of this batch.
        taskid = GenTestCase._new_task_id()
        try:
            if not batch_api_task.constraint.db_dir_path:
                batch_api_task.constraint.db_dir_path = os.path.dirname(batch_api_task.constraint.ccjson)
            # Expand .c and .h entries into the functions they contain.
            batch_api_task.api_name_list = get_api_list_by_file(
                batch_api_task.constraint.db_dir_path, batch_api_task.api_name_list
            )
            for api_name in batch_api_task.api_name_list:
                # Derive the test-case file path and normalise the case name.
                case_name = api_name
                for suffix in [".cpp", ".cc", ".c"]:
                    case_name = case_name.replace(suffix, "")
                case_name = case_name.replace("::", "_").replace(" : ", "_")
                test_case_file_path = os.path.join(
                    batch_api_task.testcase_output_dir,
                    TESTCASE_FILE_NAME.format(case_name.lower())
                )
                api_task = ApiTask(
                    user_name=batch_api_task.user_name,
                    api_name=api_name,
                    is_cpp=batch_api_task.is_cpp,
                    constraint=batch_api_task.constraint,
                    testcase_file_path=test_case_file_path
                )
                UTGENTASKQUEUE.add_task(api_task.model_dump_json(), taskid)
        except Exception as error:
            # Keep the same JSON object shape as the success response so callers
            # can always read "msg" (and still get the log path for debugging).
            ret["msg"] = "failed"
            UTGENTASKQUEUE.update_api_tasks_err_message(taskid, error)
            logger.error(f"errormsg: {str(error)}\n{traceback.format_exc()}")
        return json.dumps(ret)

    @staticmethod
    def ut_gen_prompt(ut_task: UTTask):
        """Generate one prompt per branch and record the branches' highlight data.

        For every prompt a fresh id is created. The prompt's ``line_cover``
        entries are stored under that id in the JSON file at
        ``ut_task.constraint.high_light_path``, merged with any existing content.
        ``function_name``/``file_path`` keys are renamed to
        ``functionName``/``path``.

        Returns:
            A list of ``{'example', 'uuid', 'prompt'}`` dicts, one per prompt.
        """
        logger.info('正在生成提示词...')
        begin = time.time()
        branch_design = GenTestCase._build_branch_design(ut_task)
        prompt_list = gen_prompt(branch_design)
        # NOTE(review): assumes prompt_list[i] corresponds to branches[i].
        branches = read_json(branch_design.branch_analysis)

        output = []
        new_highlights = {}
        for i, testcase in enumerate(prompt_list):
            sign = GenTestCase._new_task_id()
            output.append({
                'example': testcase.get('comment_macro', ''),
                'uuid': sign,
                'prompt': testcase.get('prompt', ''),
            })
            # Default to an empty list ('' has no .copy()); entries are renamed
            # to the key names used in the highlight file.
            line_cover = list(branches[i].get("line_cover", []))
            for item in line_cover:
                item["functionName"] = item.pop("function_name")
                item["path"] = item.pop("file_path")
            new_highlights[sign] = line_cover

        # Read and write the highlight file once instead of once per prompt.
        if new_highlights:
            try:
                highlight = read_json(ut_task.constraint.high_light_path)
            except (FileNotFoundError, json.JSONDecodeError):
                highlight = {}
            highlight.update(new_highlights)
            with open(ut_task.constraint.high_light_path, 'w') as f:
                json.dump(highlight, f, indent=2)

        times = time.time() - begin
        logger.info(f'已生成提示词,用时:{times: .3f}s')
        return output

    @staticmethod
    def ut_gencase_for_single_branch(single_prompt: SinglePrompt,
                                     model=setting.model, mode=setting.mode):
        """Generate a test case for a single branch prompt.

        Args:
            single_prompt: Provides ``user_name``, ``prompt`` and ``uuid``.
            model: LLM model name; defaults to ``setting.model`` as captured when
                this class was defined.
            mode: LLM request mode; defaults to ``setting.mode`` likewise.

        Returns:
            ``{'result': <post-processed test case text>}``.
        """
        user_name = single_prompt.user_name
        prompt = single_prompt.prompt
        sign = single_prompt.uuid
        prompt_logger.info(prompt)
        llm_request = LlmRequest(
            user_id=user_name,
            prompt=prompt,
            model=model,
            mode=mode,
            additional_args={'max_tokens': 2048}
        )
        logger.info('正在生成分支用例... ')
        begin = time.time()
        testcase = get_testcase(request=llm_request, i=0, session_id=sign)
        answer = testcase_post_processing(testcase)
        times = time.time() - begin
        logger.info(f'已生成分支用例,用时:{times: .3f}s')
        return {'result': answer}

    @staticmethod
    def ut_gencase_for_all_branches(all_prompts: AllPrompts,
                                    model=setting.model, mode=setting.mode):
        """Generate test cases for every prompt concurrently (3 worker threads).

        Returns:
            ``{'result': [{'data': <case text>, 'uuid': <prompt uuid>}, ...]}``
            in the same order as ``all_prompts.prompt_list``.
        """
        user_name = all_prompts.user_name
        prompt_list = [_.prompt for _ in all_prompts.prompt_list]
        sign_list = [_.uuid for _ in all_prompts.prompt_list]
        with ThreadPoolExecutor(max_workers=3) as t:
            futures = []
            for i, (prompt, sign) in enumerate(zip(prompt_list, sign_list)):
                llm_request = LlmRequest(
                    user_id=user_name,
                    prompt=prompt,
                    model=model,
                    mode=mode,
                    additional_args={'max_tokens': 2048}
                )
                prompt_logger.info(prompt)
                futures.append(t.submit(get_testcase, request=llm_request, i=i, session_id=sign))
            # Futures are consumed in submission order, so entry i matches prompt i.
            output = [
                {'data': testcase_post_processing(future.result()), 'uuid': sign}
                for future, sign in zip(futures, sign_list)
            ]
        logger.info(f'已生成全部分支用例')
        return {'result': output}

    @staticmethod
    def start_gen_gencase_task():
        """Process at most one pending "llvm" task: write and compile its cases.

        Marks the task and its ApiTask ``pending_case`` while running and
        ``end`` afterwards, including on error (the error is recorded on the
        ApiTask).
        """
        seq_id, llvm_task, api_task_id = UTGENTASKQUEUE.get_llvm_pending_task()
        try:
            if not seq_id and GenTestCase.is_generating_gencase_task:
                # Queue empty after at least one task was processed: log completion once.
                logger.info("生成用例任务执行完成")
                logger.info('#所有任务执行完成#')
                GenTestCase.is_generating_gencase_task = False
            if not seq_id:
                return
            GenTestCase.is_generating_gencase_task = True
            UTGENTASKQUEUE.update_llvm_task_status(seq_id, TaskStatus.pending_case)
            UTGENTASKQUEUE.update_task_status(api_task_id, TaskStatus.pending_case)
            current_api_task = UTGENTASKQUEUE.get_task_by_id(api_task_id)
            if not current_api_task:
                # NOTE(review): both statuses stay at pending_case in this case.
                return
            compile_task = CompileProcessor(llvm_task=llvm_task,
                                            api_task=current_api_task)
            compile_task.pre_writing()
            compile_task.run()
            logger.info(f'{current_api_task.api_name}用例写入完毕')
            UTGENTASKQUEUE.update_llvm_task_status(seq_id, TaskStatus.end)
            UTGENTASKQUEUE.update_task_status(api_task_id, TaskStatus.end)
        except KeyboardInterrupt:
            logger.info('用户中断执行')
        except Exception as e:
            UTGENTASKQUEUE.update_api_tasks_err_message(api_task_id, e)
            UTGENTASKQUEUE.update_llvm_task_status(seq_id, TaskStatus.end)
            UTGENTASKQUEUE.update_task_status(api_task_id, TaskStatus.end)
            logger.error(f'生成用例失败: {str(e)}\n{traceback.format_exc()}')

    @staticmethod
    def start_gen_llvm_task():
        """Turn at most one pending ApiTask into an LLVMTask and enqueue it.

        Runs branch analysis and prompt generation, then adds the result to the
        "llvm" queue read by ``start_gen_gencase_task``. On error the message is
        recorded on the ApiTask and its status is set to ``TaskStatus.end``.
        """
        api_task_id, apitask, taskid = UTGENTASKQUEUE.get_pending_task()
        try:
            if not apitask and GenTestCase.is_generating_llvm_task:
                # No pending ApiTask left after at least one was processed.
                if GenTestCase.is_generating_gencase_task:
                    GenTestCase.is_generating_gencase_task = False
                    logger.info('生成llvm任务完成')
                else:
                    logger.info('生成llvm任务完成')
                    logger.info('#所有任务执行完成#')
            if apitask:
                GenTestCase.is_generating_llvm_task = True
                UTGENTASKQUEUE.update_task_status(api_task_id, TaskStatus.pending_prompt)
                # Produce the code/branch analysis files.
                branch_design = GenTestCase._build_branch_design(apitask)
                # Build the prompt list from the branch design.
                prompt_list = gen_prompt(branch_design)
                llvmtask = LLVMTask(
                    user_name=apitask.user_name,
                    prompt_list=prompt_list,
                    task_id=taskid,
                    apitask_id=api_task_id,
                    includes=branch_design.includes,
                    branch_analysis=branch_design.branch_analysis
                )
                # Push onto the llvm queue.
                GenTestCase.is_generating_gencase_task = True
                UTGENTASKQUEUE.add_llvm_task(json.dumps(llvmtask.model_dump(), indent=4), api_task_id)
                UTGENTASKQUEUE.update_task_status(api_task_id, TaskStatus.pending_llvm)
            else:
                GenTestCase.is_generating_llvm_task = False
        except KeyboardInterrupt:
            logger.info('用户中断执行')
        except Exception as e:
            UTGENTASKQUEUE.update_api_tasks_err_message(api_task_id, e)
            UTGENTASKQUEUE.update_task_status(api_task_id, TaskStatus.end)
            logger.error(f'添加任务失败:{str(e)}\n{traceback.format_exc()}')
最新发布