The Executor Class -- the _invoke Method

   The core method of the Taskflow executor, _invoke, executes a single task node and handles the follow-on scheduling logic. It does four things:

  1. Execute the work of the given task node

  2. Handle the task's pre- and post-conditions (e.g., semaphore acquisition/release)

  3. Manage task dependencies

  4. Schedule successor tasks

// Procedure: _invoke
inline void Executor::_invoke(Worker& worker, Node* node) {

  #define TF_INVOKE_CONTINUATION()  \
  if (cache) {                      \
    node = cache;                   \
    goto begin_invoke;              \
  }

  begin_invoke:

  Node* cache {nullptr};
  
  // if this is the second invoke due to preemption, directly jump to invoke task
  if(node->_nstate & NSTATE::PREEMPTED) {
    goto invoke_task;
  }

  // if the work has been cancelled, there is no need to continue
  if(node->_is_cancelled()) {
    _tear_down_invoke(worker, node, cache);
    TF_INVOKE_CONTINUATION();
    return;
  }

  // if acquiring semaphore(s) exists, acquire them first
  if(node->_semaphores && !node->_semaphores->to_acquire.empty()) {
    SmallVector<Node*> waiters;
    if(!node->_acquire_all(waiters)) {
      _schedule(worker, waiters.begin(), waiters.end());
      return;
    }
  }
  
  invoke_task:
  
  SmallVector<int> conds;

  // switch is faster than nested if-else due to jump table
  switch(node->_handle.index()) {
    // static task
    case Node::STATIC:{
      _invoke_static_task(worker, node);
    }
    break;
    
    // runtime task
    case Node::RUNTIME:{
      if(_invoke_runtime_task(worker, node)) {
        return;
      }
    }
    break;

    // subflow task
    case Node::SUBFLOW: {
      if(_invoke_subflow_task(worker, node)) {
        return;
      }
    }
    break;

    // condition task
    case Node::CONDITION: {
      _invoke_condition_task(worker, node, conds);
    }
    break;

    // multi-condition task
    case Node::MULTI_CONDITION: {
      _invoke_multi_condition_task(worker, node, conds);
    }
    break;

    // module task
    case Node::MODULE: {
      if(_invoke_module_task(worker, node)) {
        return;
      }
    }
    break;

    // async task
    case Node::ASYNC: {
      if(_invoke_async_task(worker, node)) {
        return;
      }
      _tear_down_async(worker, node, cache);
      TF_INVOKE_CONTINUATION();
      return;
    }
    break;

    // dependent async task
    case Node::DEPENDENT_ASYNC: {
      if(_invoke_dependent_async_task(worker, node)) {
        return;
      }
      _tear_down_dependent_async(worker, node, cache);
      TF_INVOKE_CONTINUATION();
      return;
    }
    break;

    // monostate (placeholder)
    default:
    break;
  }

  // if releasing semaphores exist, release them
  if(node->_semaphores && !node->_semaphores->to_release.empty()) {
    SmallVector<Node*> waiters;
    node->_release_all(waiters);
    _schedule(worker, waiters.begin(), waiters.end());
  }

  // Reset the join counter with strong dependencies to support cycles.
  // + We must do this before scheduling the successors to avoid race
  //   condition on _predecessors.
  // + We must use fetch_add instead of direct assigning
  //   because the user-space call on "invoke" may explicitly schedule 
  //   this task again (e.g., pipeline) which can access the join_counter.
  node->_join_counter.fetch_add(
    node->num_predecessors() - (node->_nstate & ~NSTATE::MASK), std::memory_order_relaxed
  );

  // acquire the parent flow counter
  auto& join_counter = (node->_parent) ? node->_parent->_join_counter :
                       node->_topology->_join_counter;

  // Invoke the task based on the corresponding type
  switch(node->_handle.index()) {

    // condition and multi-condition tasks
    case Node::CONDITION:
    case Node::MULTI_CONDITION: {
      for(auto cond : conds) {
        if(cond >= 0 && static_cast<size_t>(cond) < node->_num_successors) {
          auto s = node->_edges[cond]; 
          // zeroing the join counter for invariant
          s->_join_counter.store(0, std::memory_order_relaxed);
          join_counter.fetch_add(1, std::memory_order_relaxed);
          _update_cache(worker, cache, s);
        }
      }
    }
    break;

    // non-condition task
    default: {
      for(size_t i=0; i<node->_num_successors; ++i) {
        if(auto s = node->_edges[i]; s->_join_counter.fetch_sub(1, std::memory_order_acq_rel) == 1) {
          join_counter.fetch_add(1, std::memory_order_relaxed);
          _update_cache(worker, cache, s);
        }
      }
    }
    break;
  }
  
  // clean up the node after execution
  _tear_down_invoke(worker, node, cache);
  TF_INVOKE_CONTINUATION();
} 

Main Components

1. Macro definition and initialization

#define TF_INVOKE_CONTINUATION()  \
  if (cache) {                    \
    node = cache;                 \
    goto begin_invoke;            \
  }

begin_invoke:
Node* cache {nullptr};
  • The TF_INVOKE_CONTINUATION macro implements execution continuation: when a successor has been cached, the worker jumps back to begin_invoke and runs it in the same stack frame instead of recursing (see the standalone sketch below)

  • The cache variable stores the next task node to execute
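
The same pattern in isolation: instead of recursively calling _invoke on the next ready node, the worker overwrites node and jumps back, reusing its own stack frame. A minimal standalone sketch (DemoNode and the loop body are illustrative, not Taskflow's types):

// Iterative "continuation": run a chain of nodes without recursion.
struct DemoNode { DemoNode* next; };

void invoke(DemoNode* node) {
  begin_invoke:
  DemoNode* cache = nullptr;
  // ... execute node's work, record at most one ready successor ...
  cache = node->next;     // stand-in for _update_cache()
  if (cache) {            // what TF_INVOKE_CONTINUATION() expands to
    node = cache;
    goto begin_invoke;
  }
}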

2. Pre-execution checks

if(node->_nstate & NSTATE::PREEMPTED) {
  goto invoke_task;
}

if(node->_is_cancelled()) {
  _tear_down_invoke(worker, node, cache);
  TF_INVOKE_CONTINUATION();
  return;
}
  • Check whether the task was preempted (PREEMPTED); if so, jump straight to execution, since the cancellation and semaphore checks already ran on the first pass

  • Check whether the task has been cancelled; if so, tear it down and return (the user-side trigger is sketched below)
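
Cancellation is requested from user code through the tf::Future returned by Executor::run; nodes that have not started yet then observe _is_cancelled() and are torn down. A minimal sketch using Taskflow's public API:

#include <taskflow/taskflow.hpp>

int main() {
  tf::Executor executor;
  tf::Taskflow taskflow;
  for(int i = 0; i < 1000; ++i) {
    taskflow.emplace([](){ /* some work */ });
  }
  tf::Future<void> fut = executor.run(taskflow);
  fut.cancel();   // remaining nodes will fail the _is_cancelled() check
  fut.get();      // wait until the run has drained
}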

3. Semaphore acquisition

if(node->_semaphores && !node->_semaphores->to_acquire.empty()) {
  SmallVector<Node*> waiters;
  if(!node->_acquire_all(waiters)) {
    _schedule(worker, waiters.begin(), waiters.end());
    return;
  }
}
  • If the task needs to acquire semaphores, try to acquire them all (the user-facing API is sketched below)

  • If acquisition fails, schedule any waiter tasks collected during the attempt and return; this node is re-invoked once the semaphores become available
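
On the user side, to_acquire and to_release are populated through Task::acquire and Task::release. A minimal sketch that serializes eight tasks through a one-slot semaphore:

#include <taskflow/taskflow.hpp>
#include <cstdio>

int main() {
  tf::Executor executor(4);
  tf::Taskflow taskflow;
  tf::Semaphore semaphore(1);   // at most one holder at a time

  for(int i = 0; i < 8; ++i) {
    tf::Task task = taskflow.emplace([i](){ std::printf("task %d\n", i); });
    task.acquire(semaphore);    // fills to_acquire, checked by _acquire_all
    task.release(semaphore);    // fills to_release, drained by _release_all
  }
  executor.run(taskflow).wait();
}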

4. Task execution (the core)

invoke_task:
switch(node->_handle.index()) {
  case Node::STATIC: { ... }
  case Node::RUNTIME: { ... }
  case Node::SUBFLOW: { ... }
  case Node::CONDITION: { ... }
  case Node::MULTI_CONDITION: { ... }
  case Node::MODULE: { ... }
  case Node::ASYNC: { ... }
  case Node::DEPENDENT_ASYNC: { ... }
  default: break;
}
  • Dispatch to the execution method that matches the task type

  • Each task category (static, runtime, subflow, condition, etc.) has its own handler; the sketch below shows how each type is created
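
The variant index of _handle is fixed when the user emplaces the task; each callable signature maps to one case of the switch. A minimal sketch with the public API:

#include <taskflow/taskflow.hpp>

int main() {
  tf::Executor executor;
  tf::Taskflow taskflow;

  // Node::STATIC -- a plain callable
  auto a = taskflow.emplace([](){});

  // Node::SUBFLOW -- a callable taking tf::Subflow&
  auto b = taskflow.emplace([](tf::Subflow& sbf){ sbf.emplace([](){}); });

  // Node::CONDITION -- a callable returning int (index of the successor to run)
  auto c = taskflow.emplace([]() -> int { return 0; });

  // Node::MULTI_CONDITION -- a callable returning tf::SmallVector<int>
  auto d = taskflow.emplace([]() -> tf::SmallVector<int> { return {0}; });

  executor.run(taskflow).wait();
}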

5. Semaphore release

if(node->_semaphores && !node->_semaphores->to_release.empty()) {
  SmallVector<Node*> waiters;
  node->_release_all(waiters);
  _schedule(worker, waiters.begin(), waiters.end());
}
  • If the task holds semaphores to release, release them

  • Wake up (schedule) the tasks waiting on those semaphores

6. Dependency management

node->_join_counter.fetch_add(
  node->num_predecessors() - (node->_nstate & ~NSTATE::MASK), std::memory_order_relaxed
);
  • Reset the task's join counter to its strong-dependency count so the node can run again, which is what makes cyclic graphs work (the arithmetic is illustrated below)

  • Use an atomic fetch_add rather than a plain store, because a user-space call (e.g., a pipeline) may re-schedule this task and touch the counter concurrently
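
The subtraction yields the strong-dependency count because, judging from the expression, the bits of _nstate outside the flag mask cache the number of weak (conditioned) predecessors. A toy illustration of that arithmetic under this assumption (the constants here are illustrative, not Taskflow's):

#include <cstdio>
#include <cstdint>

// Illustration only: high nibble = state flags, low bits assumed to
// cache the weak-predecessor count.
constexpr uint32_t MASK      = 0xF0000000u;
constexpr uint32_t PREEMPTED = 0x20000000u;

int main() {
  uint32_t nstate = PREEMPTED | 2u;  // a state flag plus 2 weak predecessors
  uint32_t num_predecessors = 5;     // total = strong + weak
  uint32_t strong = num_predecessors - (nstate & ~MASK);
  std::printf("join counter reset to %u\n", strong);  // prints 3
}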

7. Successor scheduling

switch(node->_handle.index()) {
  case Node::CONDITION:
  case Node::MULTI_CONDITION: {
    // 条件任务特殊处理
  }
  break;
  
  default: {
    // 普通任务处理
    for(size_t i=0; i<node->_num_successors; ++i) {
      if(auto s = node->_edges[i]; s->_join_counter.fetch_sub(1, std::memory_order_acq_rel) == 1) {
        join_counter.fetch_add(1, std::memory_order_relaxed);
        _update_cache(worker, cache, s);
      }
    }
  }
  break;
}
  • For condition tasks, schedule only the successors selected by the returned indices, zeroing their join counters first to preserve the invariant (see the loop example below)

  • For ordinary tasks, decrement every successor's join counter and schedule each successor whose counter reaches zero
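
The indexed-successor rule is what makes loops expressible: a condition task that returns 0 re-schedules its first successor even though that successor's dependencies were already consumed, hence the join-counter zeroing. A minimal cyclic graph:

#include <taskflow/taskflow.hpp>
#include <cstdio>

int main() {
  tf::Executor executor;
  tf::Taskflow taskflow;

  auto init = taskflow.emplace([](){ std::printf("init\n"); });
  auto body = taskflow.emplace([](){ std::printf("body\n"); });
  auto cond = taskflow.emplace([i = 0]() mutable -> int {
    return ++i < 3 ? 0 : 1;    // 0: loop back to body, 1: proceed to done
  });
  auto done = taskflow.emplace([](){ std::printf("done\n"); });

  init.precede(body);
  body.precede(cond);
  cond.precede(body, done);    // successor 0 = body, successor 1 = done

  executor.run(taskflow).wait();
}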

8. Cleanup and continuation

_tear_down_invoke(worker, node, cache);
TF_INVOKE_CONTINUATION();
  • Tear down the current task

  • If a successor task was cached, continue executing it

Key Design Features

  1. State management

    • Task state is tracked with _nstate flag bits (e.g., PREEMPTED)

    • Atomic counters track task dependencies

  2. Performance optimizations

    • switch instead of nested if-else, enabling a jump table

    • goto-based continuation avoids the overhead of recursive calls

    • Explicit memory orders (relaxed, acquire-release) keep the atomic operations cheap

  3. Task-type support

    • Supports multiple task types (static, dynamic, condition, async, etc.)

    • Each type has its own dedicated handling logic

  4. Semaphore support

    • Tasks can declare semaphores to acquire and to release

    • Semaphore waiter queues are managed automatically

  5. Dependency handling

    • Supports both ordinary and conditional (weak) dependencies

    • Supports cyclic dependency graphs

Execution Flow Summary

  1. Check the task's state (preemption, cancellation)

  2. Acquire the required semaphores

  3. Execute the task according to its type

  4. Release any held semaphores

  5. Update the dependency counters

  6. Schedule the eligible successor tasks

  7. Tear down the current task

  8. If a successor was cached, continue with it

This method is the heart of Taskflow's scheduler: through careful state management and low-level optimizations, it achieves high-performance task execution.
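
As a closing reference, a complete program that drives _invoke end to end: four static nodes forming a diamond, where A's completion decrements B's and C's join counters, and D runs once both reach zero:

#include <taskflow/taskflow.hpp>
#include <cstdio>

int main() {
  tf::Executor executor;   // its workers are what call _invoke
  tf::Taskflow taskflow;

  auto [A, B, C, D] = taskflow.emplace(
    [](){ std::printf("A\n"); },
    [](){ std::printf("B\n"); },
    [](){ std::printf("C\n"); },
    [](){ std::printf("D\n"); }
  );
  A.precede(B, C);   // A before B and C
  D.succeed(B, C);   // D after both B and C

  executor.run(taskflow).wait();
}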

 
