今天,我们不仅要学会如何开发自定义节点,更要理解 Dify 节点系统的设计哲学,掌握从设计到测试的全流程开发技巧。
一、节点开发规范深度解析
1.1 节点架构的精髓
在深入代码之前,我们先来理解 Dify 节点系统的核心设计思想。
每个节点都继承自 BaseNode,这不是简单的继承关系,而是精心设计的架构模式:
# 来源:api/core/workflow/nodes/base/node.py
class BaseNode(Generic[GenericNodeData]):
    """Base class of workflow nodes (abridged excerpt).

    Subclasses set ``_node_data_cls`` and ``_node_type`` and implement
    ``_run``. The constructor validates the node's ``data`` config against
    ``_node_data_cls`` and stores the result on ``self.node_data``.
    """

    _node_data_cls: type[GenericNodeData]  # model class used to validate the node's "data" config
    _node_type: NodeType  # enum member identifying the node type

    def __init__(self,
                 id: str,
                 config: Mapping[str, Any],
                 graph_init_params: "GraphInitParams",
                 graph: "Graph",
                 graph_runtime_state: "GraphRuntimeState",
                 previous_node_id: Optional[str] = None,
                 thread_pool_id: Optional[str] = None) -> None:
        # Basic node identity, copied from the graph init parameters.
        self.id = id
        self.tenant_id = graph_init_params.tenant_id
        self.app_id = graph_init_params.app_id
        self.workflow_type = graph_init_params.workflow_type
        # ... further attribute initialisation is elided in this excerpt
        # Key step: validate the raw "data" mapping and convert it to the typed model.
        node_data = self._node_data_cls.model_validate(config.get("data", {}))
        self.node_data = node_data

    @abstractmethod
    def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
        """Core execution method; every subclass must implement it."""
        raise NotImplementedError
设计亮点分析:
- 泛型约束:Generic[GenericNodeData] 确保类型安全
- 状态隔离:每个节点实例拥有独立的运行时状态
- 配置验证:使用 Pydantic 进行配置校验,避免运行时错误
- 错误恢复:内置重试和错误处理机制
1.2 节点生命周期管理
理解节点的生命周期对开发至关重要:
# 节点执行的完整流程
def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
    """Execute the node and emit its events.

    Calls ``self._run()``. Any exception raised by that call is logged and
    converted into a FAILED ``NodeRunResult``. A ``NodeRunResult`` is wrapped
    in a single ``RunCompletedEvent``; any other return value is treated as
    an event stream and re-yielded as is.
    """
    try:
        # 1. Run the node-specific logic.
        outcome = self._run()
    except Exception as exc:
        # 2. Turn the failure into a result instead of propagating it.
        logger.exception(f"Node {self.node_id} failed to run")
        outcome = NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            error=str(exc),
            error_type="WorkflowNodeError",
        )

    # 3. Streaming nodes hand back a generator: forward its events unchanged.
    if not isinstance(outcome, NodeRunResult):
        yield from outcome
        return

    # Non-streaming nodes produce exactly one completion event.
    yield RunCompletedEvent(run_result=outcome)
二、自定义逻辑节点实战开发
让我们通过一个实际的例子来学习如何开发自定义节点。我们要开发一个 数据验证节点,它能够验证输入数据的格式,并根据验证结果进行不同的处理。
2.1 定义节点数据结构
首先,我们需要定义节点的配置数据结构:
# entities.py
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
class ValidationRule(BaseModel):
    """One validation rule applied to a single field of the input data."""

    # Path of the field to check (dot-separated).
    field_path: str = Field(default=..., description="字段路径,支持点号分隔")
    # Which kind of check to run.
    rule_type: Literal["required", "type", "range", "regex", "custom"] = Field(
        default=...,
        description="规则类型",
    )
    # Expected value or range; its shape depends on ``rule_type``.
    expected_value: Any = Field(default=None, description="期望值或范围")
    # Message reported when the check fails.
    error_message: str = Field(default="", description="验证失败时的错误信息")
class DataValidationNodeData(BaseNodeData):
    """Configuration of the data-validation node."""

    # Selector of the variable to validate. The node validates this value
    # against ``VariableSelector``, and workflow configs pass it as a mapping
    # such as {"variable": ..., "value_selector": [...]}, so the mapping form
    # must be accepted here. ``str`` is kept for backward compatibility with
    # configs that give a plain variable name.
    input_variable: Dict[str, Any] | str = Field(..., description="输入数据变量")
    # Rules to evaluate, in order.
    validation_rules: List[ValidationRule] = Field(default_factory=list, description="验证规则列表")
    # "strict" or "loose" validation mode.
    validation_mode: Literal["strict", "loose"] = Field(default="strict", description="验证模式")
    # Output switches.
    output_valid_data: bool = Field(default=True, description="是否输出有效数据")
    output_errors: bool = Field(default=True, description="是否输出错误信息")
    # What to do when validation fails.
    on_failure: Literal["stop", "continue", "branch"] = Field(default="stop", description="验证失败时的处理方式")
设计思考:
- 使用 Pydantic 进行配置验证,确保运行时数据安全
- 支持多种验证规则类型,扩展性良好
- 灵活的失败处理策略,适应不同业务场景
2.2 实现核心验证逻辑
接下来实现节点的核心逻辑:
# data_validation_node.py
import re
import logging
from collections.abc import Mapping, Sequence
from typing import Any, Dict, List, Optional, Union
from jsonpath import JSONPathMatch
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
from .entities import DataValidationNodeData, ValidationRule
logger = logging.getLogger(__name__)
class DataValidationNode(BaseNode[DataValidationNodeData]):
    """Data-validation node: checks the input data against the configured rules.

    Results are reported through ``NodeRunResult.outputs``: ``is_valid`` and
    ``validation_details`` always, ``errors``/``warnings`` and ``data``
    depending on the output switches of the node configuration.

    Two kinds of problems are distinguished:

    * a rule evaluates to "not satisfied" (or its field cannot be reached):
      this is a validation error in every mode;
    * a rule cannot be evaluated at all (bad regex, malformed range, broken
      custom expression, unsupported type name): an error in ``strict`` mode,
      only a warning in ``loose`` mode.
    """

    _node_data_cls = DataValidationNodeData
    _node_type = NodeType.DATA_VALIDATION  # must be added to the NodeType enum

    # Python types accepted for each ``expected_value`` of a "type" rule.
    _TYPE_CHECKS = {
        "string": str,
        "number": (int, float),
        "boolean": bool,
        "array": list,
        "object": dict,
    }
    # Exceptions that walking a field path can raise on a missing/odd field.
    _FIELD_ACCESS_ERRORS = (KeyError, IndexError, TypeError, ValueError, AttributeError)

    def _run(self) -> NodeRunResult:
        """Fetch the input, validate it and build the node result.

        Never raises: unexpected exceptions are logged and turned into a
        FAILED result.
        """
        try:
            # 1. Fetch the input data.
            input_data = self._get_input_data()
            if input_data is None:
                return self._create_error_result("输入数据为空")
            # 2. Validate it.
            validation_result = self._validate_data(input_data)
            # 3. Map the validation outcome to a node result.
            return self._process_validation_result(input_data, validation_result)
        except Exception as e:
            logger.exception(f"数据验证节点执行失败: {str(e)}")
            return self._create_error_result(f"执行异常: {str(e)}")

    def _get_input_data(self) -> Any:
        """Return the object stored under the configured input variable, or None."""
        raw_selector = self.node_data.input_variable
        if isinstance(raw_selector, str):
            # A bare string cannot be validated as a VariableSelector.
            # NOTE(review): it is read here as a dot-separated selector path
            # ("node_id.variable"); confirm this is the intended meaning.
            value_selector = raw_selector.split(".")
        else:
            value_selector = VariableSelector.model_validate(raw_selector).value_selector
        variable = self.graph_runtime_state.variable_pool.get(value_selector)
        return variable.to_object() if variable else None

    def _validate_data(self, data: Any) -> Dict[str, Any]:
        """Evaluate every configured rule against ``data``.

        Returns a dict with ``is_valid``, ``errors``, ``warnings`` and
        ``validated_fields`` (one entry per rule that could be evaluated).
        """
        result = {
            "is_valid": True,
            "errors": [],
            "warnings": [],
            "validated_fields": []
        }
        strict = self.node_data.validation_mode == "strict"
        for rule in self.node_data.validation_rules:
            try:
                field_result = self._validate_field(data, rule)
            except Exception as e:
                # The rule itself could not be evaluated; the mode decides
                # whether that invalidates the data.
                logger.warning(f"验证规则执行失败: {rule.field_path} - {str(e)}")
                if strict:
                    result["is_valid"] = False
                    result["errors"].append(f"规则执行失败: {str(e)}")
                else:
                    result["warnings"].append(f"规则执行失败: {str(e)}")
                continue
            result["validated_fields"].append(field_result)
            if not field_result["is_valid"]:
                result["is_valid"] = False
                result["errors"].append(field_result["error"])
        return result

    def _validate_field(self, data: Any, rule: ValidationRule) -> Dict[str, Any]:
        """Validate one field and describe the outcome.

        A field that cannot be reached is reported as invalid. Exceptions
        raised while *evaluating the rule* propagate to the caller so that
        strict/loose handling can be applied.
        """
        field_result = {
            "field_path": rule.field_path,
            "rule_type": rule.rule_type,
            "is_valid": True,
            "error": None,
            "actual_value": None
        }
        try:
            field_value = self._get_field_value(data, rule.field_path)
        except self._FIELD_ACCESS_ERRORS as e:
            field_result["is_valid"] = False
            field_result["error"] = f"字段访问失败: {str(e)}"
            return field_result
        field_result["actual_value"] = field_value
        if not self._apply_validation_rule(field_value, rule):
            field_result["is_valid"] = False
            field_result["error"] = rule.error_message or f"{rule.field_path} 验证失败"
        return field_result

    def _get_field_value(self, data: Any, field_path: str) -> Any:
        """Resolve ``field_path`` inside ``data``.

        Supports dot-separated segments and one bracket per segment, e.g.
        ``items[0].quantity`` or ``[3].name``. Missing plain keys yield None;
        a failing bracket access raises (KeyError/IndexError/TypeError).
        """
        if not field_path:
            return data
        current = data
        for part in field_path.split('.'):
            if '[' in part and ']' in part:
                # Segment with an index/key, e.g. items[0]
                field_name, _, remainder = part.partition('[')
                index_str = remainder.split(']')[0]
                if field_name:
                    current = current[field_name]
                # isdecimal() (unlike isdigit()) only accepts what int() can parse.
                if index_str.isdecimal():
                    current = current[int(index_str)]
                else:
                    # Non-numeric content is used as a dictionary key.
                    current = current[index_str]
            elif isinstance(current, dict):
                current = current.get(part)
            else:
                current = getattr(current, part, None)
        return current

    def _apply_validation_rule(self, value: Any, rule: ValidationRule) -> bool:
        """Return whether ``value`` satisfies ``rule``.

        Raises whatever the rule evaluation raises when the rule itself is
        malformed (invalid regex, range config that is not a mapping,
        unsupported type name, failing custom expression); the caller maps
        that to an error or a warning depending on the validation mode.
        """
        rule_type = rule.rule_type
        expected = rule.expected_value

        if rule_type == "required":
            return value is not None and value != ""

        if rule_type == "type":
            if not isinstance(expected, str) or expected not in self._TYPE_CHECKS:
                # Previously an unknown type name silently passed every value.
                raise ValueError(f"不支持的类型: {expected}")
            if expected == "number" and isinstance(value, bool):
                # bool is a subclass of int but is not a number for validation.
                return False
            return isinstance(value, self._TYPE_CHECKS[expected])

        if rule_type == "range":
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                return False
            min_val = expected.get("min")
            max_val = expected.get("max")
            if min_val is not None and value < min_val:
                return False
            if max_val is not None and value > max_val:
                return False
            return True

        if rule_type == "regex":
            if not isinstance(value, str):
                return False
            # re.match anchors at the start only; patterns needing a full
            # match must end with "$".
            return bool(re.match(expected, value))

        if rule_type == "custom":
            # SECURITY: this evaluates an expression taken from the node
            # configuration. Emptying __builtins__ is NOT a sandbox (object
            # introspection can still reach arbitrary code). Run this in a
            # real sandbox or use a restricted expression evaluator before
            # exposing it to untrusted workflow authors.
            local_vars = {"value": value, "re": re}
            return bool(eval(expected, {"__builtins__": {}}, local_vars))

        return True

    def _process_validation_result(self, input_data: Any, validation_result: Dict[str, Any]) -> NodeRunResult:
        """Build the node result from the validation outcome.

        The node fails only when the data is invalid and ``on_failure`` is
        "stop"; with "continue"/"branch" it succeeds and downstream nodes
        read ``is_valid`` from the outputs.
        """
        is_valid = validation_result["is_valid"]
        outputs = {}
        if self.node_data.output_valid_data:
            outputs["data"] = input_data if is_valid else None
        if self.node_data.output_errors:
            outputs["errors"] = validation_result["errors"]
            outputs["warnings"] = validation_result.get("warnings", [])
        outputs["is_valid"] = is_valid
        outputs["validation_details"] = validation_result["validated_fields"]

        if not is_valid and self.node_data.on_failure == "stop":
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                outputs=outputs,
                error=f"数据验证失败: {'; '.join(validation_result['errors'])}"
            )
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs=outputs
        )

    def _create_error_result(self, error_message: str) -> NodeRunResult:
        """Return a FAILED result carrying ``error_message``."""
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            error=error_message,
            outputs={"is_valid": False, "errors": [error_message]}
        )
核心设计亮点:
- 灵活的字段访问:支持点号分隔和数组索引,能处理复杂的嵌套数据结构
- 多种验证规则:从基础的类型检查到自定义表达式,满足各种验证需求
- 优雅的错误处理:区分警告和错误,支持严格模式和宽松模式
- 可配置的失败策略:停止执行、继续执行或分支处理,适应不同业务场景
2.3 添加节点配置界面
为了让节点能在 Dify 的可视化界面中使用,我们需要定义前端配置:
// web/app/components/workflow/nodes/data-validation/index.tsx
import React, { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { RiAddLine, RiDeleteBinLine } from '@remixicon/react'
import type { DataValidationNodeData, ValidationRule } from './types'
import { NodeRunningStatus } from '@/app/components/workflow/types'
type Props = {
  id: string
  data: DataValidationNodeData
  onChange: (data: DataValidationNodeData) => void
}

/**
 * Configuration panel of the data-validation node.
 *
 * Every edit is reported through `onChange` with a new data object; the
 * component keeps no local state.
 *
 * NOTE(review): `VariableSelector` and `ValidationRuleConfig` are used below
 * but are not imported in this snippet; add the imports for your project.
 */
const DataValidationNode: React.FC<Props> = ({ id, data, onChange }) => {
  const { t } = useTranslation()

  // Ids derived from the node id so each <label> is tied to its <select>.
  const modeSelectId = `${id}-validation-mode`
  const failureSelectId = `${id}-on-failure`

  const handleRuleChange = useCallback((index: number, rule: ValidationRule) => {
    const newRules = [...data.validation_rules]
    newRules[index] = rule
    onChange({
      ...data,
      validation_rules: newRules
    })
  }, [data, onChange])

  const addRule = useCallback(() => {
    const newRule: ValidationRule = {
      field_path: '',
      rule_type: 'required',
      expected_value: null,
      error_message: ''
    }
    onChange({
      ...data,
      validation_rules: [...data.validation_rules, newRule]
    })
  }, [data, onChange])

  const deleteRule = useCallback((index: number) => {
    const newRules = data.validation_rules.filter((_, i) => i !== index)
    onChange({
      ...data,
      validation_rules: newRules
    })
  }, [data, onChange])

  return (
    <div className="space-y-4">
      {/* Input data selection */}
      <div>
        <label className="block text-sm font-medium text-gray-700 mb-2">
          {t('workflow.nodes.dataValidation.inputData')}
        </label>
        <VariableSelector
          value={data.input_variable}
          onChange={(value) => onChange({ ...data, input_variable: value })}
          placeholder={t('workflow.nodes.dataValidation.selectInputData')}
        />
      </div>
      {/* Validation rules */}
      <div>
        <div className="flex items-center justify-between mb-3">
          <label className="block text-sm font-medium text-gray-700">
            {t('workflow.nodes.dataValidation.validationRules')}
          </label>
          {/* type="button" keeps the click from submitting an enclosing form */}
          <button
            type="button"
            onClick={addRule}
            className="flex items-center gap-1 px-2 py-1 text-sm bg-blue-500 text-white rounded hover:bg-blue-600"
          >
            <RiAddLine size={14} />
            {t('workflow.nodes.dataValidation.addRule')}
          </button>
        </div>
        <div className="space-y-3">
          {/* NOTE(review): rules carry no id, so the index is the key; rows
              below a deleted rule are re-used by React. */}
          {data.validation_rules.map((rule, index) => (
            <ValidationRuleConfig
              key={index}
              rule={rule}
              onChange={(newRule) => handleRuleChange(index, newRule)}
              onDelete={() => deleteRule(index)}
            />
          ))}
        </div>
      </div>
      {/* Advanced options */}
      <div className="border-t pt-4">
        <h4 className="text-sm font-medium text-gray-700 mb-3">
          {t('workflow.nodes.dataValidation.advancedOptions')}
        </h4>
        <div className="grid grid-cols-2 gap-4">
          <div>
            <label htmlFor={modeSelectId} className="block text-sm text-gray-600 mb-1">
              {t('workflow.nodes.dataValidation.validationMode')}
            </label>
            <select
              id={modeSelectId}
              value={data.validation_mode}
              onChange={(e) => onChange({ ...data, validation_mode: e.target.value as 'strict' | 'loose' })}
              className="w-full px-3 py-1 border border-gray-300 rounded text-sm"
            >
              <option value="strict">{t('workflow.nodes.dataValidation.strictMode')}</option>
              <option value="loose">{t('workflow.nodes.dataValidation.looseMode')}</option>
            </select>
          </div>
          <div>
            <label htmlFor={failureSelectId} className="block text-sm text-gray-600 mb-1">
              {t('workflow.nodes.dataValidation.onFailure')}
            </label>
            <select
              id={failureSelectId}
              value={data.on_failure}
              onChange={(e) => onChange({ ...data, on_failure: e.target.value as 'stop' | 'continue' | 'branch' })}
              className="w-full px-3 py-1 border border-gray-300 rounded text-sm"
            >
              <option value="stop">{t('workflow.nodes.dataValidation.stopOnFailure')}</option>
              <option value="continue">{t('workflow.nodes.dataValidation.continueOnFailure')}</option>
              <option value="branch">{t('workflow.nodes.dataValidation.branchOnFailure')}</option>
            </select>
          </div>
        </div>
      </div>
    </div>
  )
}

export default DataValidationNode
三、工具节点开发实践
除了逻辑处理节点,Dify 还支持工具类节点的开发。让我们看看如何开发一个实用的工具节点。
3.1 理解工具节点架构
工具节点的架构与普通节点有所不同,它更多地依赖 Dify 的工具系统:
# 参考:api/core/workflow/nodes/tool/tool_node.py
class ToolNode(BaseNode[ToolNodeData]):
    """Tool node: invokes a built-in or custom tool (abridged excerpt)."""

    _node_data_cls = ToolNodeData
    _node_type = NodeType.TOOL

    def _run(self) -> Generator[Union[NodeEvent, InNodeEvent], None, None]:
        """Invoke the tool and stream the resulting events.

        ``PluginDaemonClientSideError`` and ``ToolInvokeError`` are reported
        as a FAILED ``RunCompletedEvent`` instead of being raised.
        """
        # 1. Obtain the tool runtime instance.
        tool_runtime = self._get_tool_runtime()
        # 2. Build the tool parameters from the variable pool and node config.
        parameters = self._generate_parameters(
            tool_parameters=tool_runtime.get_merged_runtime_parameters(),
            variable_pool=self.graph_runtime_state.variable_pool,
            node_data=self.node_data,
        )
        # 3. Invoke the tool.
        try:
            message_stream = ToolEngine.generic_invoke(
                tool=tool_runtime,
                tool_parameters=parameters,
                user_id=self.user_id,
                workflow_node_id=self.node_id,
            )
            # 4. Convert the tool messages into node events.
            # NOTE(review): `tool_info` is not defined anywhere in this
            # excerpt; presumably it is built in code elided here.
            yield from self._transform_message(message_stream, tool_info, parameters)
        except (PluginDaemonClientSideError, ToolInvokeError) as e:
            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    error=f"工具调用失败: {str(e)}",
                    error_type=type(e).__name__,
                )
            )
3.2 开发自定义API调用工具
让我们开发一个强化版的 HTTP 请求工具,支持更丰富的功能:
# enhanced_http_tool.py
import json
import logging
import time
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urljoin, urlparse
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import BaseToolEngine
from core.tools.errors import ToolInvokeError
logger = logging.getLogger(__name__)
class EnhancedHttpTool(BaseToolEngine):
    """Enhanced HTTP request tool.

    Adds retries, timeouts, JSON/text body handling, a basic SSRF guard and a
    formatted text summary of the response on top of a plain HTTP call.
    """

    def get_runtime_parameters(self) -> List[ToolParameter]:
        """Return the parameter definitions exposed by the tool."""
        parameter_specs = [
            {
                "name": "url",
                "label": "请求URL",
                "type": "string",
                "required": True,
                "form": "llm"
            },
            {
                "name": "method",
                "label": "HTTP方法",
                "type": "select",
                "required": True,
                "default": "GET",
                "options": ["GET", "POST", "PUT", "DELETE", "PATCH"],
                "form": "form"
            },
            {
                "name": "headers",
                "label": "请求头",
                "type": "string",
                "required": False,
                "form": "llm",
                "llm_description": "JSON格式的请求头,如:{\"Content-Type\": \"application/json\"}"
            },
            {
                "name": "body",
                "label": "请求体",
                "type": "string",
                "required": False,
                "form": "llm",
                "llm_description": "请求体数据,支持JSON字符串或表单数据"
            },
            {
                "name": "timeout",
                "label": "超时时间(秒)",
                "type": "number",
                "required": False,
                "default": 30,
                "form": "form"
            },
            {
                "name": "retry_times",
                "label": "重试次数",
                "type": "number",
                "required": False,
                "default": 3,
                "form": "form"
            },
            {
                "name": "follow_redirects",
                "label": "跟随重定向",
                "type": "boolean",
                "required": False,
                "default": True,
                "form": "form"
            },
        ]
        return [ToolParameter.model_validate(spec) for spec in parameter_specs]

    def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
        """Run the HTTP request described by ``tool_parameters``.

        Never raises: any failure is logged and returned as a text message so
        the caller always receives a ``ToolInvokeMessage``.
        """
        try:
            # 1. Parse and validate the parameters.
            config = self._parse_parameters(tool_parameters)
            # 2./3. Build the session and send the request; the context
            # manager closes the session (and its connection pool) afterwards.
            with self._create_session(config) as session:
                response_data = self._execute_request(session, config)
            # 4. Format the response.
            return self._process_response(response_data)
        except Exception as e:
            logger.exception(f"HTTP请求执行失败: {str(e)}")
            return ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.TEXT,
                message=f"HTTP请求失败: {str(e)}"
            )

    @staticmethod
    def _to_bool(value: Any, default: bool) -> bool:
        """Interpret a form value as a boolean ("false"/"0"/"no" are False)."""
        if value is None:
            return default
        if isinstance(value, str):
            return value.strip().lower() in ("true", "1", "yes", "on")
        return bool(value)

    def _parse_parameters(self, tool_parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Parse and validate the raw tool parameters.

        Returns a config dict with ``url``, ``method``, ``timeout``,
        ``retry_times``, ``follow_redirects``, ``headers`` and optionally
        ``json_body`` or ``text_body``.

        Raises:
            ToolInvokeError: empty/invalid URL, internal address, or headers
                that are not a JSON object.
        """
        retry_raw = tool_parameters.get("retry_times")
        config = {
            # ``or`` also covers parameters that are present but None/empty.
            "url": str(tool_parameters.get("url") or "").strip(),
            "method": str(tool_parameters.get("method") or "GET").upper(),
            "timeout": float(tool_parameters.get("timeout") or 30),
            # 0 is a valid retry count, so only None/"" fall back to the default.
            "retry_times": max(0, int(3 if retry_raw in (None, "") else retry_raw)),
            # bool("false") is True, hence the explicit conversion.
            "follow_redirects": self._to_bool(tool_parameters.get("follow_redirects"), True),
        }
        # URL validation
        if not config["url"]:
            raise ToolInvokeError("URL不能为空")
        parsed_url = urlparse(config["url"])
        if parsed_url.scheme not in ("http", "https") or not parsed_url.netloc:
            raise ToolInvokeError("URL格式无效")
        # Security check: refuse internal addresses.
        # NOTE(review): redirect targets are not re-validated; disable
        # follow_redirects where SSRF protection matters.
        if self._is_internal_ip(parsed_url.hostname):
            raise ToolInvokeError("禁止访问内网地址")

        # Request headers
        headers = tool_parameters.get("headers") or ""
        if isinstance(headers, dict):
            parsed_headers = headers
        elif headers:
            try:
                parsed_headers = json.loads(headers)
            except json.JSONDecodeError as e:
                raise ToolInvokeError("请求头格式无效,必须为有效的JSON") from e
        else:
            parsed_headers = {}
        if not isinstance(parsed_headers, dict):
            # Valid JSON that is not an object (list, number, ...) is unusable.
            raise ToolInvokeError("请求头格式无效,必须为有效的JSON")
        # requests only accepts string header values.
        config["headers"] = {str(k): str(v) for k, v in parsed_headers.items()}

        # Request body
        body = tool_parameters.get("body") or ""
        if body and config["method"] in ("POST", "PUT", "PATCH"):
            try:
                config["json_body"] = json.loads(body)
            except json.JSONDecodeError:
                # Not JSON: send as plain text.
                config["text_body"] = body
            else:
                # Keep a Content-Type the caller supplied (any letter case).
                if not any(name.lower() == "content-type" for name in config["headers"]):
                    config["headers"]["Content-Type"] = "application/json"
        return config

    def _is_internal_ip(self, hostname: str) -> bool:
        """Return True when ``hostname`` is, or resolves to, an internal address.

        Host names are resolved so that a public-looking name pointing at a
        private address is rejected as well. This performs a DNS lookup and
        does not protect against DNS rebinding between check and request.
        """
        if not hostname:
            return False
        import ipaddress
        import socket

        def is_blocked(ip) -> bool:
            return (ip.is_private or ip.is_loopback or ip.is_link_local
                    or ip.is_reserved or ip.is_multicast or ip.is_unspecified)

        try:
            return is_blocked(ipaddress.ip_address(hostname))
        except ValueError:
            pass  # not an IP literal: treat as a host name

        lowered = hostname.lower()
        if lowered == "localhost" or lowered.endswith(".localhost"):
            return True
        try:
            resolved = socket.getaddrinfo(hostname, None)
        except (socket.gaierror, UnicodeError):
            # Unresolvable: the request itself will fail with a clear error.
            return False
        for info in resolved:
            address = info[4][0].split("%", 1)[0]  # drop an IPv6 zone id
            try:
                if is_blocked(ipaddress.ip_address(address)):
                    return True
            except ValueError:
                continue
        return False

    def _create_session(self, config: Dict[str, Any]) -> requests.Session:
        """Create a session with the retry policy and default headers applied."""
        session = requests.Session()
        # NOTE(review): POST/PATCH are retried too; a retried non-idempotent
        # request can apply its side effect twice. Confirm this is intended.
        retry_strategy = Retry(
            total=config["retry_times"],
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["HEAD", "GET", "OPTIONS", "POST", "PUT", "DELETE", "PATCH"]
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        session.headers.update({
            "User-Agent": "Dify-EnhancedHttpTool/1.0",
            **config.get("headers", {})
        })
        return session

    def _execute_request(self, session: requests.Session, config: Dict[str, Any]) -> Dict[str, Any]:
        """Send the request and return a dict describing the response.

        Raises:
            ToolInvokeError: on timeout, connection error or any other
                ``requests`` failure.
        """
        request_kwargs = {
            "method": config["method"],
            "url": config["url"],
            "timeout": config["timeout"],
            "allow_redirects": config["follow_redirects"]
        }
        if "json_body" in config:
            request_kwargs["json"] = config["json_body"]
        elif "text_body" in config:
            request_kwargs["data"] = config["text_body"]

        # perf_counter is monotonic, so the duration cannot go negative.
        start_time = time.perf_counter()
        try:
            response = session.request(**request_kwargs)
        except requests.exceptions.Timeout as e:
            raise ToolInvokeError(f"请求超时({config['timeout']}秒)") from e
        except requests.exceptions.ConnectionError as e:
            raise ToolInvokeError("连接错误,请检查网络或URL") from e
        except requests.exceptions.RequestException as e:
            raise ToolInvokeError(f"请求失败: {str(e)}") from e
        execution_time = time.perf_counter() - start_time

        response_data = {
            "status_code": response.status_code,
            "headers": dict(response.headers),
            "execution_time": round(execution_time, 3),
            "url": response.url,
            "request_method": config["method"],
            "success": response.ok
        }
        # Response body: JSON when declared and parseable, text otherwise.
        content_type = response.headers.get("Content-Type", "").lower()
        if "application/json" in content_type:
            try:
                response_data["json"] = response.json()
                response_data["content_type"] = "json"
                return response_data
            except ValueError:
                pass  # declared JSON but not parseable: fall back to text
        response_data["text"] = response.text
        response_data["content_type"] = "text"
        return response_data

    def _process_response(self, response_data: Dict[str, Any]) -> ToolInvokeMessage:
        """Format ``response_data`` as a readable text message."""
        status_emoji = "✅" if response_data["success"] else "❌"
        message_parts = [
            f"{status_emoji} **HTTP {response_data['request_method']} 请求完成**",
            f"📍 **URL:** {response_data['url']}",
            f"📊 **状态码:** {response_data['status_code']}",
            f"⏱️ **执行时间:** {response_data['execution_time']}秒",
        ]
        # Header names keep the server's letter case once converted to a
        # plain dict, so look them up case-insensitively.
        headers_by_lower_name = {
            str(name).lower(): value for name, value in response_data["headers"].items()
        }
        important_headers = ["content-type", "content-length", "server", "date"]
        headers_info = [
            f" • {header.title()}: {headers_by_lower_name[header]}"
            for header in important_headers
            if header in headers_by_lower_name
        ]
        if headers_info:
            message_parts.append("📋 **响应头信息:**")
            message_parts.extend(headers_info)
        # Response body, truncated to keep the message short.
        if response_data["content_type"] == "json":
            message_parts.append("📄 **响应内容 (JSON):**")
            json_str = json.dumps(response_data["json"], ensure_ascii=False, indent=2)
            if len(json_str) > 2000:
                json_str = json_str[:2000] + "\n... (内容过长,已截断)"
            message_parts.append(f"```json\n{json_str}\n```")
        else:
            text_content = response_data.get("text", "")
            if text_content:
                message_parts.append("📄 **响应内容:**")
                if len(text_content) > 1000:
                    text_content = text_content[:1000] + "\n... (内容过长,已截断)"
                message_parts.append(f"```\n{text_content}\n```")
        return ToolInvokeMessage(
            type=ToolInvokeMessage.MessageType.TEXT,
            message="\n\n".join(message_parts)
        )
四、节点测试方法详解
开发完节点后,测试是确保质量的关键环节。让我们看看如何进行全面的节点测试。
4.1 单元测试框架搭建
# tests/unit_tests/workflow/nodes/test_data_validation_node.py
import pytest
from unittest.mock import Mock, patch
from typing import Any, Dict
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.data_validation.data_validation_node import DataValidationNode
from core.workflow.nodes.data_validation.entities import DataValidationNodeData, ValidationRule
from models.workflow import WorkflowNodeExecutionStatus
class TestDataValidationNode:
    """Unit tests for the data-validation node."""

    @pytest.fixture
    def mock_graph_runtime_state(self):
        """Graph runtime state whose variable pool is a mock."""
        runtime_state = Mock(spec=GraphRuntimeState)
        runtime_state.variable_pool = Mock(spec=VariablePool)
        return runtime_state

    @pytest.fixture
    def basic_node_config(self):
        """Node config with a "required" rule on name and a "type" rule on age."""
        name_rule = {
            "field_path": "name",
            "rule_type": "required",
            "expected_value": None,
            "error_message": "姓名不能为空"
        }
        age_rule = {
            "field_path": "age",
            "rule_type": "type",
            "expected_value": "number",
            "error_message": "年龄必须为数字"
        }
        return {
            "id": "test_validation_node",
            "data": {
                "input_variable": "input_data",
                "validation_rules": [name_rule, age_rule],
                "validation_mode": "strict",
                "on_failure": "stop"
            }
        }

    @staticmethod
    def _run_node(runtime_state, config, payload):
        """Make the pool return ``payload``, build the node and run it."""
        pooled_variable = Mock()
        pooled_variable.to_object.return_value = payload
        runtime_state.variable_pool.get.return_value = pooled_variable
        node = DataValidationNode(
            id="test_node",
            config=config,
            graph_init_params=Mock(),
            graph=Mock(),
            graph_runtime_state=runtime_state
        )
        return node._run()

    def test_successful_validation(self, mock_graph_runtime_state, basic_node_config):
        """Valid data succeeds and is passed through."""
        test_data = {"name": "张三", "age": 25}

        result = self._run_node(mock_graph_runtime_state, basic_node_config, test_data)

        assert isinstance(result, NodeRunResult)
        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["is_valid"] is True
        assert result.outputs["data"] == test_data
        assert len(result.outputs["errors"]) == 0

    def test_validation_failure(self, mock_graph_runtime_state, basic_node_config):
        """Data violating both rules fails with two errors."""
        test_data = {"name": "", "age": "not_a_number"}

        result = self._run_node(mock_graph_runtime_state, basic_node_config, test_data)

        assert result.status == WorkflowNodeExecutionStatus.FAILED
        assert result.outputs["is_valid"] is False
        assert len(result.outputs["errors"]) == 2
        assert "姓名不能为空" in result.outputs["errors"]

    def test_complex_field_path(self, mock_graph_runtime_state):
        """Nested paths and list indexes are resolved."""
        email_rule = {
            "field_path": "user.profile.email",
            "rule_type": "regex",
            "expected_value": r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$",
            "error_message": "邮箱格式无效"
        }
        quantity_rule = {
            "field_path": "items[0].quantity",
            "rule_type": "range",
            "expected_value": {"min": 1, "max": 100},
            "error_message": "数量必须在1-100之间"
        }
        config = {
            "id": "test_node",
            "data": {
                "input_variable": "input_data",
                "validation_rules": [email_rule, quantity_rule]
            }
        }
        test_data = {
            "user": {"profile": {"email": "test@example.com"}},
            "items": [{"quantity": 50}]
        }

        result = self._run_node(mock_graph_runtime_state, config, test_data)

        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs["is_valid"] is True

    @pytest.mark.parametrize("validation_mode,expected_status", [
        ("strict", WorkflowNodeExecutionStatus.FAILED),
        ("loose", WorkflowNodeExecutionStatus.SUCCEEDED)
    ])
    def test_validation_modes(self, mock_graph_runtime_state, validation_mode, expected_status):
        """A rule that cannot be evaluated fails strict mode but not loose mode."""
        broken_rule = {
            "field_path": "invalid_field",
            "rule_type": "custom",
            "expected_value": "invalid_python_code(",  # deliberate syntax error
            "error_message": "自定义验证失败"
        }
        config = {
            "id": "test_node",
            "data": {
                "input_variable": "input_data",
                "validation_rules": [broken_rule],
                "validation_mode": validation_mode,
                "on_failure": "stop"
            }
        }

        result = self._run_node(mock_graph_runtime_state, config, {"test": "data"})

        assert result.status == expected_status
4.2 集成测试实践
除了单元测试,我们还需要集成测试来验证节点在实际工作流中的表现:
# tests/integration_tests/workflow/test_data_validation_integration.py
import json
import pytest
from typing import Dict, Any
from core.workflow.graph_engine.graph_engine import GraphEngine
from tests.integration_tests.workflow.workflow_test_base import WorkflowTestBase
class TestDataValidationIntegration(WorkflowTestBase):
    """Integration tests for the data-validation node inside a workflow."""

    @staticmethod
    def _build_validation_workflow() -> Dict[str, Any]:
        """Return a start -> code -> data_validation -> llm -> end workflow."""
        # Source of the code node. It catches only the JSON decoding error;
        # a bare ``except:`` would also swallow SystemExit/KeyboardInterrupt.
        parse_json_code = (
            "\n"
            "import json\n"
            "def main(user_input: str) -> dict:\n"
            "    try:\n"
            "        data = json.loads(user_input)\n"
            "        return {\"result\": data, \"success\": True}\n"
            "    except json.JSONDecodeError:\n"
            "        return {\"result\": None, \"success\": False}\n"
        )
        start_node = {
            "id": "start",
            "type": "start",
            "data": {
                "inputs": [
                    {
                        "variable": "user_input",
                        "type": "paragraph",
                        "label": "用户输入",
                        "required": True
                    }
                ]
            }
        }
        parse_node = {
            "id": "parse_json",
            "type": "code",
            "data": {
                "code_language": "python3",
                "code": parse_json_code,
                "variables": [
                    {
                        "variable": "user_input",
                        "value_selector": ["start", "user_input"]
                    }
                ],
                "outputs": {
                    "result": "object",
                    "success": "boolean"
                }
            }
        }
        validation_node = {
            "id": "validate_data",
            "type": "data_validation",
            "data": {
                "input_variable": {
                    "variable": "parsed_data",
                    "value_selector": ["parse_json", "result"]
                },
                "validation_rules": [
                    {
                        "field_path": "name",
                        "rule_type": "required",
                        "error_message": "姓名不能为空"
                    },
                    {
                        "field_path": "email",
                        "rule_type": "regex",
                        "expected_value": r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$",
                        "error_message": "邮箱格式无效"
                    }
                ],
                "validation_mode": "strict",
                "on_failure": "continue"
            }
        }
        llm_node = {
            "id": "process_result",
            "type": "llm",
            "data": {
                "model": {
                    "provider": "openai",
                    "name": "gpt-3.5-turbo",
                    "mode": "chat"
                },
                "prompt_template": [
                    {
                        "role": "user",
                        "text": "根据验证结果处理用户数据:\n验证状态: {{#validate_data.is_valid#}}\n错误信息: {{#validate_data.errors#}}\n原始数据: {{#validate_data.data#}}"
                    }
                ],
                "variables": [
                    {
                        "variable": "validation_result",
                        "value_selector": ["validate_data", "is_valid"]
                    },
                    {
                        "variable": "validation_errors",
                        "value_selector": ["validate_data", "errors"]
                    },
                    {
                        "variable": "original_data",
                        "value_selector": ["validate_data", "data"]
                    }
                ]
            }
        }
        end_node = {
            "id": "end",
            "type": "end",
            "data": {
                "outputs": [
                    {
                        "variable": "final_result",
                        "value_selector": ["process_result", "text"]
                    }
                ]
            }
        }
        return {
            "version": "0.1.0",
            "environment_variables": [],
            "conversation_variables": [],
            "graph": {
                "nodes": [start_node, parse_node, validation_node, llm_node, end_node],
                "edges": [
                    {"source": "start", "target": "parse_json"},
                    {"source": "parse_json", "target": "validate_data"},
                    {"source": "validate_data", "target": "process_result"},
                    {"source": "process_result", "target": "end"}
                ]
            }
        }

    def test_validation_in_complete_workflow(self):
        """The validation node runs inside a complete workflow.

        NOTE(review): the assertions on ``final_result`` depend on the wording
        produced by the LLM node and may be flaky against a live model.
        """
        workflow_config = self._build_validation_workflow()

        # Valid data
        valid_input = '{"name": "张三", "email": "zhangsan@example.com", "age": 25}'
        result = self.run_workflow(workflow_config, {"user_input": valid_input})
        assert result.status == "succeeded"
        assert "验证通过" in result.outputs.get("final_result", "")

        # Invalid data
        invalid_input = '{"name": "", "email": "invalid-email", "age": 25}'
        result = self.run_workflow(workflow_config, {"user_input": invalid_input})
        assert result.status == "succeeded"  # on_failure is "continue"
        assert "验证失败" in result.outputs.get("final_result", "")

    def test_validation_with_conditional_branch(self):
        """Validation node combined with an if/else branch (not implemented).

        Skipped explicitly so the placeholder is not reported as a pass.
        TODO: build a workflow with an if/else node and cover both the
        success branch and the failure branch.
        """
        pytest.skip("conditional-branch workflow test is not implemented yet")
4.3 性能测试和压力测试
对于可能处理大量数据的节点,性能测试不可缺少:
# tests/performance_tests/test_data_validation_performance.py
import time
import pytest
from concurrent.futures import ThreadPoolExecutor, as_completed
class TestDataValidationPerformance:
    """Performance tests for the data-validation node."""

    # Rules shared by the concurrency test.
    basic_rules = [
        {"field_path": "name", "rule_type": "required"},
        {"field_path": "age", "rule_type": "type", "expected_value": "number"},
    ]

    @staticmethod
    def create_validation_node(data, rules):
        """Build a node whose mocked variable pool returns ``data``.

        The imports are local because this snippet's own import block only
        brings in ``time``, ``pytest`` and ``concurrent.futures``.
        """
        from unittest.mock import Mock

        from core.workflow.nodes.data_validation.data_validation_node import DataValidationNode

        pooled_variable = Mock()
        pooled_variable.to_object.return_value = data
        runtime_state = Mock()
        runtime_state.variable_pool.get.return_value = pooled_variable
        config = {
            "id": "perf_node",
            "data": {
                "input_variable": "input_data",
                "validation_rules": list(rules),
                "validation_mode": "strict",
                "on_failure": "stop",
            },
        }
        return DataValidationNode(
            id="perf_node",
            config=config,
            graph_init_params=Mock(),
            graph=Mock(),
            graph_runtime_state=runtime_state,
        )

    def test_large_dataset_validation(self):
        """Validating 10000 records in one run stays within the time budget."""
        from models.workflow import WorkflowNodeExecutionStatus

        large_dataset = [
            {"id": i, "name": f"user_{i}", "email": f"user{i}@example.com"}
            for i in range(10000)
        ]
        rule_templates = [
            {"field_path": "id", "rule_type": "type", "expected_value": "number"},
            {"field_path": "name", "rule_type": "required"},
            {"field_path": "email", "rule_type": "regex",
             "expected_value": r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"}
        ]
        # The node validates one object per run, and the input here is a
        # list, so every rule is addressed to its record as "[i].field".
        validation_rules = [
            {**template, "field_path": f"[{index}].{template['field_path']}"}
            for index in range(len(large_dataset))
            for template in rule_templates
        ]
        node = self.create_validation_node(large_dataset, validation_rules)

        # Time only the validation run, with a monotonic clock.
        start_time = time.perf_counter()
        result = node._run()
        execution_time = time.perf_counter() - start_time

        # 10000 records must be validated in under 5 seconds.
        assert execution_time < 5.0
        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        print(f"验证10000条记录耗时: {execution_time:.2f}秒")

    def test_concurrent_validation(self):
        """50 validations on 10 worker threads all succeed."""
        from models.workflow import WorkflowNodeExecutionStatus

        def run_validation():
            test_data = {"name": "测试用户", "age": 25}
            node = self.create_validation_node(test_data, self.basic_rules)
            return node._run()

        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(run_validation) for _ in range(50)]
            success_count = sum(
                1
                for future in as_completed(futures)
                if future.result().status == WorkflowNodeExecutionStatus.SUCCEEDED
            )
        assert success_count == 50
五、节点开发最佳实践总结
通过前面的实战开发,我总结出以下最佳实践:
5.1 代码质量保证
- 类型安全:充分利用 Python 的类型提示和 Pydantic 验证
- 错误处理:完善的异常捕获和用户友好的错误信息
- 日志记录:关键操作点都要有适当的日志记录
- 参数验证:输入参数的严格验证,防止运行时错误
5.2 性能优化策略
# 性能优化示例
class OptimizedValidationNode(DataValidationNode):
    """Validation node that compiles its regex rules once, at construction.

    Compiled patterns are cached by pattern text, so several regex rules on
    the same field each keep their own pattern.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # pattern text -> compiled pattern
        self._compiled_regexes = {}
        self._precompile_regexes()

    def _precompile_regexes(self):
        """Compile every string pattern used by a regex rule.

        Patterns that fail to compile are logged and left out of the cache;
        such rules are then handled by the base implementation.
        """
        for rule in self.node_data.validation_rules:
            pattern = rule.expected_value
            if rule.rule_type != "regex" or not isinstance(pattern, str) or not pattern:
                continue
            if pattern in self._compiled_regexes:
                continue
            try:
                self._compiled_regexes[pattern] = re.compile(pattern)
            except re.error as e:
                logger.warning(f"正则表达式编译失败: {rule.expected_value} - {e}")

    def _apply_validation_rule(self, value: Any, rule: ValidationRule) -> bool:
        """Apply ``rule``, using the precompiled pattern for regex rules.

        Gives the same verdict as the base class: a non-string value never
        satisfies a regex rule (it is not coerced with ``str()``).
        """
        if rule.rule_type == "regex" and isinstance(rule.expected_value, str):
            compiled_regex = self._compiled_regexes.get(rule.expected_value)
            if compiled_regex is not None:
                if not isinstance(value, str):
                    return False
                return bool(compiled_regex.match(value))
        return super()._apply_validation_rule(value, rule)
5.3 可扩展性设计
# 插件化验证规则
class ValidationRuleRegistry:
    """Registry mapping rule-type names to validator callables."""

    # Shared by the whole process: rule type -> validator.
    _rules = dict()

    @classmethod
    def register(cls, rule_type: str):
        """Return a decorator that stores the decorated callable under ``rule_type``.

        The callable is returned unchanged; registering the same type again
        replaces the earlier validator.
        """
        def _remember(validator):
            cls._rules.update({rule_type: validator})
            return validator

        return _remember

    @classmethod
    def get_rule(cls, rule_type: str):
        """Return the validator registered for ``rule_type``, or None."""
        try:
            return cls._rules[rule_type]
        except KeyError:
            return None
# 注册自定义规则
@ValidationRuleRegistry.register("custom_phone")
def validate_phone_number(value: Any, expected_value: Any) -> bool:
    """Return True when ``value`` is a mainland-China mobile number.

    Accepts exactly 11 digits: a leading 1, a second digit from 3 to 9, then
    nine more digits. ``expected_value`` is not used; it is part of the
    validator signature expected by the registry's callers.
    """
    if not isinstance(value, str):
        return False
    # fullmatch rejects a trailing newline, which "^...$" with match() accepts.
    return re.fullmatch(r'1[3-9]\d{9}', value) is not None
# 在节点中使用
def _apply_validation_rule(self, value: Any, rule: ValidationRule) -> bool:
    """Apply ``rule``, preferring a validator registered for its rule type.

    Falls back to the parent implementation when no validator is registered.
    """
    registered = ValidationRuleRegistry.get_rule(rule.rule_type)
    if not registered:
        # Nothing registered for this type: use the built-in rules.
        return super()._apply_validation_rule(value, rule)
    return registered(value, rule.expected_value)
5.4 前端集成优化
// 节点配置组件的性能优化
// Memoised skeleton of the node configuration component; the actual form
// markup is elided (see the placeholder inside the returned <div>).
const DataValidationNodeConfig = React.memo(({ data, onChange }) => {
  // Derive the rule list with an `id` per rule; recomputed only when
  // `data.validation_rules` changes.
  // NOTE(review): `optimizedRules` is not referenced in the JSX shown here.
  const optimizedRules = useMemo(() => {
    return data.validation_rules.map((rule, index) => ({
      ...rule,
      id: `${rule.field_path}_${index}` // key built from field_path and index; NOTE(review): it changes when field_path is edited or rows shift, so it is not stable across edits
    }))
  }, [data.validation_rules])
  // Debounced wrapper around onChange (300 ms), rebuilt when `onChange` changes.
  // NOTE(review): not referenced in the JSX shown here, and a pending call is
  // never cancelled on unmount; confirm against the debounce helper in use.
  const debouncedOnChange = useMemo(
    () => debounce(onChange, 300),
    [onChange]
  )
  return (
    <div className="space-y-4">
      {/* configuration UI components go here */}
    </div>
  )
})
六、结语
自定义节点开发是 Dify 扩展能力的核心。通过本章的学习,你应该能够:
- 理解节点架构:掌握 Dify 节点系统的设计原理
- 开发实用节点:能够开发满足实际业务需求的自定义节点
- 保证代码质量:通过测试确保节点的可靠性和性能
- 遵循最佳实践:写出可维护、可扩展的高质量代码
给想要深入的朋友几个建议:
- 多看源码:Dify 的内置节点实现都很精彩,值得仔细研读
- 注重测试:节点代码的测试覆盖率要达到 80% 以上
- 关注性能:大数据处理场景下的性能优化不可忽视
- 积极贡献:把你开发的优秀节点贡献给社区
下一章,我们将深入探讨 Provider 接入开发,学习如何为 Dify 添加新的模型提供商支持。相信你会发现,Provider 的设计同样精彩!
如果你在节点开发过程中遇到问题,欢迎在评论区交流。记住,最好的学习方式就是动手实践 - 现在就开始开发你的第一个自定义节点吧!