Original article: Spring AI Alibaba Graph Source Code Walkthrough Series: Core Startup Classes
[!TIP]
Based on Spring AI Alibaba version 1.0.0.3
Core classes

OverAllState
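The core state-management class used during graph execution. It stores and manages the data shared between the nodes of a graph.
| Field name | Field type | Description |
| --- | --- | --- |
| data | Map |
Public methods
| Category | Method name | Description |
| --- | --- | --- |
| Constructor | OverAllState | Supports four construction forms - No-arg: the default constructor; initializes empty data and strategy maps and registers the default input key - Map |
The static nested class HumanFeedback handles and stores human feedback received during workflow execution
- Map<String, Object> data: the feedback payload; it can hold arbitrary key-value pairs
- String nextNodeId: the ID of the node to execute next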
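The way OverAllState merges incoming data can be illustrated with a minimal, self-contained sketch. It does not use the library: the class `KeyStrategySketch` and the `REPLACE` and `APPEND` strategies are hypothetical stand-ins, modeled as `BinaryOperator<Object>` on the assumption that a key strategy combines the old and the new value for a key. As in the `input` and `updateState` methods of the listing that follows, keys without a registered strategy are ignored.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BinaryOperator;

public class KeyStrategySketch {

    // Hypothetical stand-in for a "replace" strategy: the new value wins.
    static final BinaryOperator<Object> REPLACE = (oldValue, newValue) -> newValue;

    // Hypothetical stand-in for an "append" strategy: values accumulate in a list.
    static final BinaryOperator<Object> APPEND = (oldValue, newValue) -> {
        List<Object> merged = new ArrayList<>();
        if (oldValue instanceof List) {
            merged.addAll((List<?>) oldValue);
        }
        merged.add(newValue);
        return merged;
    };

    private final Map<String, Object> data = new HashMap<>();
    private final Map<String, BinaryOperator<Object>> strategies = new HashMap<>();

    KeyStrategySketch register(String key, BinaryOperator<Object> strategy) {
        strategies.put(key, strategy);
        return this;
    }

    // Applies only the keys that have a registered strategy; other keys are skipped.
    Map<String, Object> update(Map<String, Object> partial) {
        partial.forEach((key, value) -> {
            BinaryOperator<Object> strategy = strategies.get(key);
            if (strategy != null) {
                data.put(key, strategy.apply(data.get(key), value));
            }
        });
        return Collections.unmodifiableMap(data);
    }

    public static void main(String[] args) {
        KeyStrategySketch state = new KeyStrategySketch()
            .register("input", REPLACE)
            .register("messages", APPEND);
        state.update(Map.of("input", "hello", "messages", "m1", "ignored", 1));
        // "input" is replaced, "messages" grows, "ignored" never enters the state.
        System.out.println(state.update(Map.of("input", "world", "messages", "m2")));
    }
}
```

The practical consequence is that a key must be registered before any node output for it can reach the state, which is why the no-arg constructor registers the default input key up front.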
package com.alibaba.cloud.ai.graph;
import org.springframework.util.CollectionUtils;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import static com.alibaba.cloud.ai.graph.utils.CollectionsUtils.entryOf;
import static java.util.Collections.unmodifiableMap;
import static java.util.Optional.ofNullable;
public final class OverAllState implements Serializable {
/**
* Internal map storing the actual state data. All get/set operations on state values
* go through this map.
*/
private final Map<String, Object> data;
/**
* Mapping of keys to their respective update strategies. Determines how values for
* each key should be merged or updated.
*/
private final Map<String, KeyStrategy> keyStrategies;
/**
* Indicates whether this state is being used to resume a previously interrupted
* execution. If true, certain initialization steps may be skipped.
*/
private Boolean resume;
/**
* Holds optional human feedback information provided during execution. May be null if
* no feedback was given.
*/
private HumanFeedback humanFeedback;
/**
* Optional message indicating that the execution was interrupted. If non-null,
* indicates that the graph should halt or handle the interruption.
*/
private String interruptMessage;
/**
* The default key used for standard input injection into the state. Typically used
* when initializing the state with user or external input.
*/
public static final String DEFAULTINPUTKEY = "input";
/**
* Reset.
*/
public void reset() {
this.data.clear();
}
/**
* Snap shot optional.
* @return the optional
*/
public Optional<OverAllState> snapShot() {
return Optional.of(new OverAllState(new HashMap<>(this.data), new HashMap<>(this.keyStrategies), this.resume));
}
/**
* Instantiates a new Over all state.
* @param resume the is resume
*/
public OverAllState(boolean resume) {
this.data = new HashMap<>();
this.keyStrategies = new HashMap<>();
this.resume = resume;
}
/**
* Instantiates a new Over all state.
* @param data the data
*/
public OverAllState(Map<String, Object> data) {
this.data = new HashMap<>(data);
this.keyStrategies = new HashMap<>();
this.resume = false;
}
/**
* Instantiates a new Over all state.
*/
public OverAllState() {
this.data = new HashMap<>();
this.keyStrategies = new HashMap<>();
this.registerKeyAndStrategy(OverAllState.DEFAULTINPUTKEY, new ReplaceStrategy());
this.resume = false;
}
/**
* Instantiates a new Over all state.
* @param data the data
* @param keyStrategies the key strategies
* @param resume the resume
*/
protected OverAllState(Map<String, Object> data, Map<String, KeyStrategy> keyStrategies, Boolean resume) {
this.data = data;
this.keyStrategies = keyStrategies;
this.registerKeyAndStrategy(OverAllState.DEFAULTINPUTKEY, new ReplaceStrategy());
this.resume = resume;
}
/**
* Interrupt message string.
* @return the string
*/
public String interruptMessage() {
return interruptMessage;
}
/**
* Sets interrupt message.
* @param interruptMessage the interrupt message
*/
public void setInterruptMessage(String interruptMessage) {
this.interruptMessage = interruptMessage;
}
/**
* With human feedback.
* @param humanFeedback the human feedback
*/
public void withHumanFeedback(HumanFeedback humanFeedback) {
this.humanFeedback = humanFeedback;
}
/**
* Human feedback human feedback.
* @return the human feedback
*/
public HumanFeedback humanFeedback() {
return this.humanFeedback;
}
/**
* Copy with resume over all state.
* @return the over all state
*/
public OverAllState copyWithResume() {
return new OverAllState(this.data, this.keyStrategies, true);
}
/**
* With resume.
*/
public void withResume() {
this.resume = true;
}
/**
* Without resume.
*/
public void withoutResume() {
this.resume = false;
}
/**
* Is resume boolean.
* @return the boolean
*/
public boolean isResume() {
return this.resume;
}
/**
* Clears all data in the current state, leaving key strategies, resume flag, and
* human feedback intact.
*/
public void clear() {
this.data.clear();
}
/**
* Replaces the current state's contents with the provided state.
* <p>
* This method effectively copies all data, key strategies, resume flag, and human
* feedback from the provided state to this state.
* @param overAllState the state to copy from
*/
public void cover(OverAllState overAllState) {
this.keyStrategies.clear();
this.keyStrategies.putAll(overAllState.keyStrategies());
this.data.clear();
this.data.putAll(overAllState.data());
this.resume = overAllState.resume;
this.humanFeedback = overAllState.humanFeedback;
}
/**
* Inputs over all state.
* @param input the input
* @return the over all state
*/
public OverAllState input(Map<String, Object> input) {
if (input == null) {
withResume();
return this;
}
if (CollectionUtils.isEmpty(input)) {
return this;
}
Map<String, KeyStrategy> keyStrategies = keyStrategies();
input.keySet().stream().filter(key -> keyStrategies.containsKey(key)).forEach(key -> {
this.data.put(key, keyStrategies.get(key).apply(value(key, null), input.get(key)));
});
return this;
}
/**
* Add key and strategy over all state.
* @param key the key
* @param strategy the strategy
* @return the over all state
*/
public OverAllState registerKeyAndStrategy(String key, KeyStrategy strategy) {
this.keyStrategies.put(key, strategy);
return this;
}
/**
* Register key and strategy over all state.
* @param keyStrategies the key strategies
* @return the over all state
*/
public OverAllState registerKeyAndStrategy(Map<String, KeyStrategy> keyStrategies) {
this.keyStrategies.putAll(keyStrategies);
return this;
}
/**
* Is contain strategy boolean.
* @param key the key
* @return the boolean
*/
public boolean containStrategy(String key) {
return this.keyStrategies.containsKey(key);
}
/**
* Update state map.
* @param partialState the partial state
* @return the map
*/
public Map<String, Object> updateState(Map<String, Object> partialState) {
Map<String, KeyStrategy> keyStrategies = keyStrategies();
partialState.keySet().stream().filter(key -> keyStrategies.containsKey(key)).forEach(key -> {
this.data.put(key, keyStrategies.get(key).apply(value(key, null), partialState.get(key)));
});
return data();
}
/**
* Updates the internal state based on a schema-defined strategy.
* <p>
* This method first validates the input state, then updates the partial state
* according to the provided key strategies. The updated state is formed by merging
* the original state and the modified partial state, removing any null values in the
* process. The resulting entries are then used to update the internal data map.
* @param state the base state to update; must not be null
* @param partialState the partial state containing updates; may be null or empty
* @param keyStrategies the mapping of keys to update strategies; used to transform
* values
*/
public void updateStateBySchema(Map<String, Object> state, Map<String, Object> partialState,
Map<String, KeyStrategy> keyStrategies) {
updateState(updateState(state, partialState, keyStrategies));
}
/**
* Key verify boolean.
* @return the boolean
*/
protected boolean keyVerify() {
return hasCommonKey(this.data, getKeyStrategies());
}
private Map<?, ?> getKeyStrategies() {
return this.keyStrategies;
}
private boolean hasCommonKey(Map<?, ?> map1, Map<?, ?> map2) {
Set<?> keys1 = map1.keySet();
for (Object key : map2.keySet()) {
if (keys1.contains(key)) {
return true;
}
}
return false;
}
/**
* Updates a state with the provided partial state. The merge function is used to
* merge the current state value with the new value.
* @param state the current state
* @param partialState the partial state to update from
* @return the updated state
* @throws NullPointerException if state is null
*/
public static Map<String, Object> updateState(Map<String, Object> state, Map<String, Object> partialState) {
Objects.requireNonNull(state, "state cannot be null");
if (partialState == null || partialState.isEmpty()) {
return state;
}
return Stream.concat(state.entrySet().stream(), partialState.entrySet().stream())
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, OverAllState::mergeFunction));
}
/**
* Update state map.
* @param state the state
* @param partialState the partial state
* @param keyStrategies the key strategies
* @return the map
*/
public static Map<String, Object> updateState(Map<String, Object> state, Map<String, Object> partialState,
Map<String, KeyStrategy> keyStrategies) {
Objects.requireNonNull(state, "state cannot be null");
if (partialState == null || partialState.isEmpty()) {
return state;
}
Map<String, Object> updatedPartialState = updatePartialStateFromSchema(state, partialState, keyStrategies);
return Stream.concat(state.entrySet().stream(), updatedPartialState.entrySet().stream())
.collect(toMapRemovingNulls(Map.Entry::getKey, Map.Entry::getValue, (currentValue, newValue) -> newValue));
}
/**
* Updates the partial state from a schema using channels.
* @param state The current state as a map of key-value pairs.
* @param partialState The partial state to be updated.
* @param keyStrategies A map of channel names to their implementations.
* @return An updated version of the partial state after applying the schema and
* channels.
*/
private static Map<String, Object> updatePartialStateFromSchema(Map<String, Object> state,
Map<String, Object> partialState, Map<String, KeyStrategy> keyStrategies) {
if (keyStrategies == null || keyStrategies.isEmpty()) {
return partialState;
}
return partialState.entrySet().stream().map(entry -> {
KeyStrategy channel = keyStrategies.get(entry.getKey());
if (channel != null) {
Object newValue = channel.apply(state.get(entry.getKey()), entry.getValue());
return entryOf(entry.getKey(), newValue);
}
return entry;
}).collect(toMapAllowingNulls(Map.Entry::getKey, Map.Entry::getValue));
}
private static <T, K, U> Collector<T, ?, Map<K, U>> toMapRemovingNulls(Function<? super T, ? extends K> keyMapper,
Function<? super T, ? extends U> valueMapper, BinaryOperator<U> mergeFunction) {
return Collector.of(HashMap::new, (map, element) -> {
K key = keyMapper.apply(element);
U value = valueMapper.apply(element);
if (value == null) {
map.remove(key);
}
else {
map.merge(key, value, mergeFunction);
}
}, (map1, map2) -> {
map2.forEach((key, value) -> {
if (value != null) {
map1.merge(key, value, mergeFunction);
}
});
return map1;
}, Collector.Characteristics.UNORDERED);
}
private static <T, K, U> Collector<T, ?, Map<K, U>> toMapAllowingNulls(Function<? super T, ? extends K> keyMapper,
Function<? super T, ? extends U> valueMapper) {
return Collector.of(HashMap::new,
(map, element) -> map.put(keyMapper.apply(element), valueMapper.apply(element)), (map1, map2) -> {
map1.putAll(map2);
return map1;
}, Collector.Characteristics.UNORDERED);
}
/**
* Merges the current value with the new value using the appropriate merge function.
* @param currentValue the current value
* @param newValue the new value
* @return the merged value
*/
private static Object mergeFunction(Object currentValue, Object newValue) {
return newValue;
}
/**
* Key strategies map.
* @return the map
*/
public Map<String, KeyStrategy> keyStrategies() {
return keyStrategies;
}
/**
* Data map.
* @return the map
*/
public final Map<String, Object> data() {
return unmodifiableMap(data);
}
/**
* Value optional.
* @param <T> the type parameter
* @param key the key
* @return the optional
*/
public final <T> Optional<T> value(String key) {
return ofNullable((T) data().get(key));
}
/**
* Value optional.
* @param <T> the type parameter
* @param key the key
* @param type the type
* @return the optional
*/
public final <T> Optional<T> value(String key, Class<T> type) {
if (type != null) {
return ofNullable(type.cast(data().get(key)));
}
return value(key);
}
/**
* Value t.
* @param <T> the type parameter
* @param key the key
* @param defaultValue the default value
* @return the t
*/
public final <T> T value(String key, T defaultValue) {
return (T) value(key).orElse(defaultValue);
}
/**
* The type Human feedback.
*/
public static class HumanFeedback implements Serializable {
private Map<String, Object> data;
private String nextNodeId;
private String currentNodeId;
/**
* Instantiates a new Human feedback.
* @param data the data
* @param nextNodeId the next node id
*/
public HumanFeedback(Map<String, Object> data, String nextNodeId) {
this.data = data;
this.nextNodeId = nextNodeId;
}
/**
* Data map.
* @return the map
*/
public Map<String, Object> data() {
return data;
}
/**
* Next node id string.
* @return the string
*/
public String nextNodeId() {
return nextNodeId;
}
/**
* Sets data.
* @param data the data
*/
public void setData(Map<String, Object> data) {
this.data = data;
}
/**
* Sets next node id.
* @param nextNodeId the next node id
*/
public void setNextNodeId(String nextNodeId) {
this.nextNodeId = nextNodeId;
}
}
@Override
public String toString() {
return "OverAllState{" + "data=" + data + ", resume=" + resume + ", humanFeedback=" + humanFeedback
+ ", interruptMessage='" + interruptMessage + '\'' + '}';
}
}
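One detail of the static three-argument `updateState` in the listing above is easy to miss: it collects with `toMapRemovingNulls`, so a key whose value is `null` is removed from the result instead of being stored. The following is a minimal stand-alone sketch of that behavior, not the library code. The class name `NullRemovingMergeSketch` and the method `merge` are hypothetical, and the per-key strategies are left out, so the newer value simply wins.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class NullRemovingMergeSketch {

    // Merges partial over state: newer values win, and a null value deletes the key.
    static Map<String, Object> merge(Map<String, Object> state, Map<String, Object> partial) {
        if (partial == null || partial.isEmpty()) {
            return state; // nothing to merge, mirror the early return in the listing
        }
        Map<String, Object> result = new HashMap<>();
        // Process the base state first, then the partial state, as Stream.concat does.
        for (Map<String, Object> source : List.of(state, partial)) {
            source.forEach((key, value) -> {
                if (value == null) {
                    result.remove(key);
                }
                else {
                    result.put(key, value);
                }
            });
        }
        return result;
    }

    public static void main(String[] args) {
        Map<String, Object> state = new HashMap<>();
        state.put("draft", "v1");
        state.put("topic", "graphs");

        Map<String, Object> partial = new HashMap<>();
        partial.put("draft", null);    // null marks the key for removal
        partial.put("topic", "state"); // non-null replaces the old value

        System.out.println(merge(state, partial)); // prints {topic=state}
    }
}
```

Under this reading, a node can drop a key from the merged map by returning `null` for it, while `toMapAllowingNulls` keeps such `null` entries alive long enough for the final collector to see them.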
RunnableConfig
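The run configuration class.
| Field name | Field type | Description |
| --- | --- | --- |
| threadId | String | Thread ID |
| checkPointId | String | Checkpoint ID |
| nextNode | String | ID of the next node to execute |
| streamMode | CompiledGraph.StreamMode | Stream mode of the compiled graph; see the description of the CompiledGraph class for details |
| metadata | Map |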
package com.alibaba.cloud.ai.graph;
import com.alibaba.cloud.ai.graph.internal.node.ParallelNode;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.Optional.ofNullable;
/**
* A final class representing configuration for a runnable task. This class holds various
* parameters such as thread ID, checkpoint ID, next node, and stream mode, providing
* methods to modify these parameters safely without permanently altering the original
* configuration.
*/
public final class RunnableConfig implements HasMetadata<RunnableConfig.Builder> {
private final String threadId;
private final String checkPointId;
private final String nextNode;
private final CompiledGraph.StreamMode streamMode;
private final Map<String, Object> metadata;
private final Map<String, Object> interruptedNodes;
/**
* Returns the stream mode of the compiled graph.
* @return {@code StreamMode} representing the current stream mode.
*/
public CompiledGraph.StreamMode streamMode() {
return streamMode;
}
/**
* Returns the thread ID as an {@link Optional}.
* @return the thread ID wrapped in an {@code Optional}, or an empty {@code Optional}
* if no thread ID is set.
*/
public Optional<String> threadId() {
return ofNullable(threadId);
}
/**
* Returns the current {@code checkPointId} wrapped in an {@link Optional}.
* @return an {@link Optional} containing the {@code checkPointId}, or
* {@link Optional#empty()} if it is null.
*/
public Optional<String> checkPointId() {
return ofNullable(checkPointId);
}
/**
* Returns an {@code Optional} describing the next node in the sequence, or an empty
* {@code Optional} if there is no such node.
* @return an {@code Optional} describing the next node, or an empty {@code Optional}
*/
public Optional<String> nextNode() {
return ofNullable(nextNode);
}
/**
* Checks if a node is marked as interrupted in the metadata.
* @param nodeId the ID of the node to check for interruption status
* @return true if the node is marked as interrupted, false otherwise
*/
public boolean isInterrupted(String nodeId) {
return interruptData(HasMetadata.formatNodeId(nodeId)).map(value -> Boolean.TRUE.equals(value)).orElse(false);
}
/**
* Marks a node as not interrupted by setting its value to false in the
* interrupted-nodes map. The map is updated in place and nothing is returned.
* @param nodeId the ID of the node to mark as not interrupted
*/
public void withNodeResumed(String nodeId) {
String formattedNodeId = HasMetadata.formatNodeId(nodeId);
interruptedNodes.put(formattedNodeId, false);
}
/**
* Removes the interrupted marker for a specific node by removing its entry from the
* interrupted-nodes map. The map is updated in place and nothing is returned.
* @param nodeId the ID of the node to remove the interrupted marker for
*/
public void removeInterrupted(String nodeId) {
String formattedNodeId = HasMetadata.formatNodeId(nodeId);
if (interruptedNodes == null || !interruptedNodes.containsKey(formattedNodeId)) {
return; // No change needed if the marker doesn't exist
}
interruptedNodes.remove(formattedNodeId);
}
/**
* Marks a node as interrupted by adding it to the interrupted-nodes map with a
* formatted key. The node ID is formatted using
* {@link HasMetadata#formatNodeId(String)} and associated with a value of
* {@code true}. The map is updated in place and nothing is returned.
* @param nodeId the ID of the node to mark as interrupted; must not be null
* @throws NullPointerException if nodeId is null
*/
public void markNodeAsInterrupted(String nodeId) {
interruptedNodes.put(HasMetadata.formatNodeId(nodeId), true);
}
/**
* Create a new RunnableConfig with the same attributes as this one but with a
* different {@link CompiledGraph.StreamMode}.
* @param streamMode the new stream mode
* @return a new RunnableConfig with the updated stream mode
*/
public RunnableConfig withStreamMode(CompiledGraph.StreamMode streamMode) {
if (this.streamMode == streamMode) {
return this;
}
return RunnableConfig.builder(this).streamMode(streamMode).build();
}
/**
* Updates the checkpoint ID of the configuration.
* @param checkPointId The new checkpoint ID to set.
* @return A new instance of {@code RunnableConfig} with the updated checkpoint ID, or
* the current instance if no change is needed.
*/
public RunnableConfig withCheckPointId(String checkPointId) {
if (Objects.equals(this.checkPointId, checkPointId)) {
return this;
}
return RunnableConfig.builder(this).checkPointId(checkPointId).build();
}
/**
* Retrieves interrupt data associated with the specified key.
* @param key the key for which to retrieve interrupt data; may be null
* @return an Optional containing the interrupt data if present, or an empty Optional
* if the key is null or no data is found
*/
public Optional<Object> interruptData(String key) {
if (key == null) {
return Optional.empty();
}
return ofNullable(interruptedNodes).map(m -> m.get(key));
}
/**
* return metadata value for key
* @param key given metadata key
* @return metadata value for key if any
*/
@Override
public Optional<Object> metadata(String key) {
if (key == null) {
return Optional.empty();
}
return ofNullable(metadata).map(m -> m.get(key));
}
/**
* Creates a new instance of the {@link Builder} class.
* @return A new {@code Builder} object.
*/
public static Builder builder() {
return new Builder();
}
/**
* Creates a new {@code Builder} instance with the specified {@link RunnableConfig}.
* @param config The configuration for the {@code Builder}.
* @return A new {@code Builder} instance.
*/
public static Builder builder(RunnableConfig config) {
return new Builder(config);
}
/**
* A builder pattern class for constructing {@link RunnableConfig} objects. This class
* provides a fluent interface to set various properties of a {@link RunnableConfig}
* object and then build the final configuration.
*/
public static class Builder extends HasMetadata.Builder<Builder> {
private String threadId;
private String checkPointId;
private String nextNode;
private CompiledGraph.StreamMode streamMode = CompiledGraph.StreamMode.VALUES;
/**
* Constructs a new instance of the {@link Builder} with default configuration
* settings. Initializes a new {@link RunnableConfig} object for configuration
* purposes.
*/
Builder() {
}
/**
* Initializes a new instance of the {@code Builder} class with the specified
* {@link RunnableConfig}.
* @param config The configuration to be used for initialization.
*/
Builder(RunnableConfig config) {
super(requireNonNull(config, "config cannot be null!").metadata);
this.threadId = config.threadId;
this.checkPointId = config.checkPointId;
this.nextNode = config.nextNode;
this.streamMode = config.streamMode;
}
/**
* Sets the ID of the thread.
* @param threadId the ID of the thread to set
* @return a reference to this {@code Builder} object so that method calls can be
* chained together
*/
public Builder threadId(String threadId) {
this.threadId = threadId;
return this;
}
/**
* Sets the checkpoint ID for the configuration.
* @param checkPointId the ID of the checkpoint to be set
* @return a reference to this {@code Builder} instance for method chaining
*/
public Builder checkPointId(String checkPointId) {
this.checkPointId = checkPointId;
return this;
}
/**
* Sets the next node in the configuration and returns this builder for method
* chaining.
* @param nextNode The next node to be set.
* @return This builder instance, allowing for method chaining.
*/
public Builder nextNode(String nextNode) {
this.nextNode = nextNode;
return this;
}
/**
* Sets the stream mode of the configuration.
* @param streamMode The {@link CompiledGraph.StreamMode} to set.
* @return A reference to this builder for method chaining.
*/
public Builder streamMode(CompiledGraph.StreamMode streamMode) {
this.streamMode = streamMode;
return this;
}
/**
* Adds a custom {@link Executor} for a specific parallel node.
* <p>
* This allows you to control the execution of branches within a parallel node.
* When a parallel node is executed, it will look for an executor in the
* {@link RunnableConfig} metadata. If found, it will be used to run the parallel
* branches concurrently.
* @param nodeId the ID of the parallel node.
* @param executor the {@link Executor} to use for the parallel node.
* @return this {@code Builder} instance for method chaining.
*/
public Builder addParallelNodeExecutor(String nodeId, Executor executor) {
return addMetadata(ParallelNode.formatNodeId(nodeId), requireNonNull(executor, "executor cannot be null!"));
}
/**
* Constructs and returns the configured {@code RunnableConfig} object.
* @return the configured {@code RunnableConfig} object
*/
public RunnableConfig build() {
return new RunnableConfig(this);
}
}
/**
* Creates a new instance of {@code RunnableConfig} from the values held by the given
* builder. The interrupted-nodes map always starts empty.
* @param builder The configuration builder.
*/
private RunnableConfig(Builder builder) {
this.threadId = builder.threadId;
this.checkPointId = builder.checkPointId;
this.nextNode = builder.nextNode;
this.streamMode = builder.streamMode;
this.metadata = ofNullable(builder.metadata()).map(Map::copyOf).orElse(null);
this.interruptedNodes = new ConcurrentHashMap<>();
}
@Override
public String toString() {
return format("RunnableConfig{ threadId=%s, checkPointId=%s, nextNode=%s, streamMode=%s }", threadId,
checkPointId, nextNode, streamMode);
}
}
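The `withStreamMode` and `withCheckPointId` methods in the listing above follow a copy-on-change pattern: they return the same instance when nothing changes and otherwise build a modified copy through the builder. A minimal sketch of that pattern, independent of the library, might look as follows. `ConfigSketch` and its two fields are hypothetical and only mirror the shape of RunnableConfig.

```java
import java.util.Objects;
import java.util.Optional;

public final class ConfigSketch {

    private final String threadId;
    private final String checkPointId;

    private ConfigSketch(Builder builder) {
        this.threadId = builder.threadId;
        this.checkPointId = builder.checkPointId;
    }

    public Optional<String> threadId() {
        return Optional.ofNullable(threadId);
    }

    public Optional<String> checkPointId() {
        return Optional.ofNullable(checkPointId);
    }

    // Returns this instance when the value is unchanged, otherwise a modified copy.
    public ConfigSketch withCheckPointId(String checkPointId) {
        if (Objects.equals(this.checkPointId, checkPointId)) {
            return this;
        }
        return builder(this).checkPointId(checkPointId).build();
    }

    public static Builder builder() {
        return new Builder();
    }

    // Seeds a builder with the values of an existing configuration.
    public static Builder builder(ConfigSketch config) {
        Builder builder = new Builder();
        builder.threadId = config.threadId;
        builder.checkPointId = config.checkPointId;
        return builder;
    }

    public static final class Builder {

        private String threadId;
        private String checkPointId;

        public Builder threadId(String threadId) {
            this.threadId = threadId;
            return this;
        }

        public Builder checkPointId(String checkPointId) {
            this.checkPointId = checkPointId;
            return this;
        }

        public ConfigSketch build() {
            return new ConfigSketch(this);
        }
    }

    public static void main(String[] args) {
        ConfigSketch base = ConfigSketch.builder().threadId("t-1").build();
        System.out.println(base.withCheckPointId(null) == base);   // prints true
        System.out.println(base.withCheckPointId("cp-1") == base); // prints false
        System.out.println(base.checkPointId().isPresent());       // prints false
    }
}
```

In RunnableConfig itself, the private constructor creates a fresh `interruptedNodes` map, so a copy produced through the builder appears to start without the interruption markers of the original; only `metadata` is carried over.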
StateGraph
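Represents and builds state-based graph workflows. It provides the following capabilities:
- Graph structure definition: an API for building a directed graph, including node and edge definitions
- State management: works with OverAllState to manage state during graph execution
- Workflow orchestration: supports complex execution logic, including conditional edges and subgraphs
| Field name | Field type | Description |
| --- | --- | --- |
| START | String | Constant identifying the graph's start node ("START") |
| END | String | Constant identifying the graph's end node ("END") |
| ERROR | String | Constant identifying the graph's error node ("ERROR") |
| NODEBEFORE | String | Constant identifying the hook that runs before a node executes ("NODEBEFORE") |
| NODEAFTER | String | Constant identifying the hook that runs after a node executes ("NODEAFTER") |
| nodes | Nodes | Container for all nodes in the graph |
| edges | Edges | Container for all edges in the graph |
| overAllStateFactory | OverAllStateFactory | Factory that creates overall state instances (deprecated) |
| keyStrategyFactory | KeyStrategyFactory | Factory that provides key strategies |
| name | String | Name of the graph |
| stateSerializer | PlainTextStateSerializer | State serializer; by default the inner class JacksonSerializer, which is Jackson-based |
Public methods
| Category | Method name | Description |
| --- | --- | --- |
| Construction | StateGraph | Supports five construction forms - No-arg - (KeyStrategyFactory keyStrategyFactory): with a key strategy factory - (String name, KeyStrategyFactory keyStrategyFactory): with a name and a key strategy factory - (KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer): with a key strategy factory and a serializer - (String name, KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer): with a name, a key strategy factory, and a serializer |
| Node management | addNode | - (String id, Node node): adds a node instance - (String id, AsyncNodeAction action): adds a plain node - (String id, AsyncNodeActionWithConfig actionWithConfig): adds a node with configuration - (String id, AsyncCommandAction action, Map |
The static nested class Nodes manages the collection of nodes in the graph
- anyMatchById(String id): checks whether a node with the given ID exists
- onlySubStateGraphNodes(): returns all subgraph nodes
- exceptSubStateGraphNodes(): returns all nodes except subgraph nodes
The static nested class Edges manages the collection of edges in the graph
- edgeBySourceId(String sourceId): finds the edge by source node ID
- edgesByTargetId(String targetId): finds edges by target node ID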
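To make the two container classes concrete, here is a small stand-alone sketch of a graph that stores node IDs and directed edges and offers lookups similar in spirit to `anyMatchById`, `edgeBySourceId`, and `edgesByTargetId`. It is not the library implementation: `GraphSketch`, its return types, the single outgoing edge per source, and the rule that `START` and `END` are rejected as node IDs are all assumptions made for illustration.

```java
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

public class GraphSketch {

    static final String START = "START";
    static final String END = "END";

    private final Set<String> nodeIds = new LinkedHashSet<>();

    // sourceId -> targetId; one outgoing edge per source keeps the sketch small.
    private final Map<String, String> edges = new LinkedHashMap<>();

    GraphSketch addNode(String id) {
        // Assumption: the reserved markers cannot be used as regular node IDs.
        if (START.equals(id) || END.equals(id)) {
            throw new IllegalArgumentException("reserved node id: " + id);
        }
        if (!nodeIds.add(id)) {
            throw new IllegalArgumentException("duplicate node id: " + id);
        }
        return this;
    }

    GraphSketch addEdge(String sourceId, String targetId) {
        edges.put(sourceId, targetId);
        return this;
    }

    boolean anyMatchById(String id) {
        return nodeIds.contains(id);
    }

    // Returns the target of the edge leaving sourceId, if there is one.
    Optional<String> edgeBySourceId(String sourceId) {
        return Optional.ofNullable(edges.get(sourceId));
    }

    // Returns the sources of all edges that point at targetId.
    List<String> edgesByTargetId(String targetId) {
        return edges.entrySet().stream()
            .filter(edge -> edge.getValue().equals(targetId))
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        GraphSketch graph = new GraphSketch()
            .addNode("a")
            .addNode("b")
            .addEdge(START, "a")
            .addEdge("a", "b")
            .addEdge("b", END);
        System.out.println(graph.anyMatchById("a"));   // prints true
        System.out.println(graph.edgeBySourceId("a")); // prints Optional[b]
        System.out.println(graph.edgesByTargetId(END)); // prints [b]
    }
}
```

Keeping lookups by source and by target separate makes it cheap to answer both questions a compiler for such a graph needs: where execution goes next, and which nodes lead into a given node.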
package com.alibaba.cloud.ai.graph;
import com.alibaba.cloud.ai.graph.action.AsyncCommandAction;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import com.alibaba.cloud.ai.graph.checkpoint.config.SaverConfig;
import com.alibaba.cloud.ai.graph.checkpoint.constant.SaverConstant;
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;
import com.alibaba.cloud.ai.graph.exception.Errors;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.internal.edge.Edge;
import com.alibaba.cloud.ai.graph.internal.edge.EdgeCondition;
import com.alibaba.cloud.ai.graph.internal.edge.EdgeValue;
import com.alibaba.cloud.ai.graph.internal.node.CommandNode;
import com.alibaba.cloud.ai.graph.internal.node.Node;
import com.alibaba.cloud.ai.graph.internal.node.SubCompiledGraphNode;
import com.alibaba.cloud.ai.graph.internal.node.SubStateGraphNode;
import com.alibaba.cloud.ai.graph.serializer.StateSerializer;
import com.alibaba.cloud.ai.graph.serializer.plaintext.PlainTextStateSerializer;
import com.alibaba.cloud.ai.graph.serializer.plaintext.jackson.JacksonStateSerializer;
import com.alibaba.cloud.ai.graph.state.AgentStateFactory;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.LinkedHashSet;
/**
* Represents a state graph with nodes and edges.
*/
public class StateGraph {
/**
* Constant representing the END of the graph.
*/
public static final String END = "END";
/**
* Constant representing the START of the graph.
*/
public static final String START = "START";
/**
* Constant representing the ERROR of the graph.
*/
public static final String ERROR = "ERROR";
/**
* Constant representing the NODEBEFORE of the graph.
*/
public static final String NODEBEFORE = "NODEBEFORE";
/**
* Constant representing the NODEAFTER of the graph.
*/
public static final String NODEAFTER = "NODEAFTER";
/**
* Collection of nodes in the graph.
*/
final Nodes nodes = new Nodes();
/**
* Collection of edges in the graph.
*/
final Edges edges = new Edges();
/**
* Factory for creating overall state instances.
*/
private OverAllStateFactory overAllStateFactory;
/**
* Factory for providing key strategies.
*/
private KeyStrategyFactory keyStrategyFactory;
/**
* Name of the graph.
*/
private String name;
/**
* Serializer for the state.
*/
private final PlainTextStateSerializer stateSerializer;
/**
* Jackson-based serializer for state.
*/
static class JacksonSerializer extends JacksonStateSerializer {
/**
* Instantiates a new Jackson serializer.
*/
public JacksonSerializer() {
super(OverAllState::new);
}
/**
* Gets object mapper.
* @return the object mapper
*/
ObjectMapper getObjectMapper() {
return objectMapper;
}
}
/**
* Constructs a StateGraph with the specified name, key strategy factory, and state
* serializer.
* @param name the name of the graph
* @param keyStrategyFactory the factory for providing key strategies
* @param stateSerializer the plain text state serializer to use
*/
public StateGraph(String name, KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer) {
this.name = name;
this.keyStrategyFactory = keyStrategyFactory;
this.stateSerializer = stateSerializer;
}
public StateGraph(KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer) {
this.keyStrategyFactory = keyStrategyFactory;
this.stateSerializer = stateSerializer;
}
/**
* Constructs a StateGraph with the given key strategy factory and name.
* @param keyStrategyFactory the factory for providing key strategies
* @param name the name of the graph
*/
public StateGraph(String name, KeyStrategyFactory keyStrategyFactory) {
this.keyStrategyFactory = keyStrategyFactory;
this.name = name;
this.stateSerializer = new JacksonSerializer();
}
/**
* Constructs a StateGraph with the provided key strategy factory.
* @param keyStrategyFactory the factory for providing key strategies
*/
public StateGraph(KeyStrategyFactory keyStrategyFactory) {
this.keyStrategyFactory = keyStrategyFactory;
this.stateSerializer = new JacksonSerializer();
}
/**
* Deprecated constructor that initializes a StateGraph with the specified name,
* overall state factory, and state serializer.
* @param name the name of the graph
* @param overAllStateFactory the factory for creating overall state instances
* @param plainTextStateSerializer the plain text state serializer to use
*/
@Deprecated
public StateGraph(String name, OverAllStateFactory overAllStateFactory,
PlainTextStateSerializer plainTextStateSerializer) {
this.name = name;
this.overAllStateFactory = overAllStateFactory;
this.stateSerializer = plainTextStateSerializer;
}
/**
* Deprecated constructor that initializes a StateGraph with the specified name and
* overall state factory.
* @param name the name of the graph
* @param overAllStateFactory the factory for creating overall state instances
*/
@Deprecated
public StateGraph(String name, OverAllStateFactory overAllStateFactory) {
this.name = name;
this.overAllStateFactory = overAllStateFactory;
this.stateSerializer = new JacksonSerializer();
}
/**
* Deprecated constructor that initializes a StateGraph with the provided overall
* state factory.
* @param overAllStateFactory the factory for creating overall state instances
*/
@Deprecated
public StateGraph(OverAllStateFactory overAllStateFactory) {
this.overAllStateFactory = overAllStateFactory;
this.stateSerializer = new JacksonSerializer();
}
/**
* Deprecated constructor that initializes a StateGraph with the provided overall
* state factory and state serializer.
* @param overAllStateFactory the factory for creating overall state instances
* @param plainTextStateSerializer the plain text state serializer to use
*/
@Deprecated
public StateGraph(OverAllStateFactory overAllStateFactory, PlainTextStateSerializer plainTextStateSerializer) {
this.overAllStateFactory = overAllStateFactory;
this.stateSerializer = plainTextStateSerializer;
}
/**
* Default constructor that initializes a StateGraph with a Jackson-based state
* serializer.
*/
public StateGraph() {
this.stateSerializer = new JacksonSerializer();
this.keyStrategyFactory = HashMap::new;
}
/**
* Gets the name of the graph.
* @return the name
*/
public String getName() {
return name;
}
/**
* Gets the state serializer used by this graph.
* @return the state serializer
*/
public StateSerializer<OverAllState> getStateSerializer() {
return stateSerializer;
}
/**
* Gets the state factory associated with this graph's state serializer.
* @return the state factory
*/
public final AgentStateFactory<OverAllState> getStateFactory() {
return stateSerializer.stateFactory();
}
/**
* Gets the overall state factory.
* @return the overall state factory
*/
@Deprecated
public final OverAllStateFactory getOverAllStateFactory() {
return overAllStateFactory;
}
/**
* Gets the key strategy factory.
* @return the key strategy factory
*/
public final KeyStrategyFactory getKeyStrategyFactory() {
return keyStrategyFactory;
}
/**
* Adds a commandNode to the graph.
* @param id the identifier of the node
* @param action AsyncCommandAction action
* @param mappings the mappings to be used for conditional edges
* @return this state graph instance
* @throws GraphStateException if the node identifier is invalid or the node already
* exists
*/
public StateGraph addNode(String id, AsyncCommandAction action, Map<String, String> mappings)
throws GraphStateException {
return addNode(id, new CommandNode(id, action, mappings));
}
/**
* Adds a node to the graph.
* @param id the identifier of the node
* @param action the asynchronous node action to be performed by the node
* @return this state graph instance
* @throws GraphStateException if the node identifier is invalid or the node already
* exists
*/
public StateGraph addNode(String id, AsyncNodeAction action) throws GraphStateException {
return addNode(id, AsyncNodeActionWithConfig.of(action));
}
/**
* Adds a node to the graph with the specified action and configuration.
* @param id the identifier of the node
* @param actionWithConfig the action to be performed by the node
* @return this state graph instance
* @throws GraphStateException if the node identifier is invalid or the node already
* exists
*/
public StateGraph addNode(String id, AsyncNodeActionWithConfig actionWithConfig) throws GraphStateException {
Node node = new Node(id, (config) -> actionWithConfig);
return addNode(id, node);
}
/**
* Adds a node to the graph with the specified identifier and node instance.
* @param id the identifier of the node
* @param node the node to be added
* @return this state graph instance
* @throws GraphStateException if the node identifier is invalid or the node already
* exists
*/
public StateGraph addNode(String id, Node node) throws GraphStateException {
if (Objects.equals(node.id(), END)) {
throw Errors.invalidNodeIdentifier.exception(END);
}
if (!Objects.equals(node.id(), id)) {
throw Errors.invalidNodeIdentifier.exception(node.id(), id);
}
if (nodes.elements.contains(node)) {
throw Errors.duplicateNodeError.exception(id);
}
nodes.elements.add(node);
return this;
}
/**
* Adds a subgraph to the state graph by creating a node with the specified
* identifier. This implies that the subgraph shares the same state with the parent
* graph.
* @param id the identifier of the node representing the subgraph
* @param subGraph the compiled subgraph to be added
* @return this state graph instance
* @throws GraphStateException if the node identifier is invalid or the node already
* exists
*/
public StateGraph addNode(String id, CompiledGraph subGraph) throws GraphStateException {
if (Objects.equals(id, END)) {
throw Errors.invalidNodeIdentifier.exception(END);
}
var node = new SubCompiledGraphNode(id, subGraph);
if (nodes.elements.contains(node)) {
throw Errors.duplicateNodeError.exception(id);
}
nodes.elements.add(node);
return this;
}
/**
* Adds a subgraph to the state graph by creating a node with the specified
* identifier. This implies that the subgraph will share the same state with the
* parent graph and will be compiled when the parent is compiled.
* @param id the identifier of the node representing the subgraph
* @param subGraph the subgraph to be added; it will be compiled during compilation of
* the parent
* @return this state graph instance
* @throws GraphStateException if the node identifier is invalid or the node already
* exists
*/
public StateGraph addNode(String id, StateGraph subGraph) throws GraphStateException {
if (Objects.equals(id, END)) {
throw Errors.invalidNodeIdentifier.exception(END);
}
subGraph.validateGraph();
var node = new SubStateGraphNode(id, subGraph);
if (nodes.elements.contains(node)) {
throw Errors.duplicateNodeError.exception(id);
}
nodes.elements.add(node);
return this;
}
/**
* Adds an edge to the graph between the specified source and target nodes.
* @param sourceId the identifier of the source node
* @param targetId the identifier of the target node
* @return this state graph instance
* @throws GraphStateException if the edge identifier is invalid or the edge already
* exists
*/
public StateGraph addEdge(String sourceId, String targetId) throws GraphStateException {
if (Objects.equals(sourceId, END)) {
throw Errors.invalidEdgeIdentifier.exception(END);
}
var newEdge = new Edge(sourceId, new EdgeValue(targetId));
int index = edges.elements.indexOf(newEdge);
if (index >= 0) {
var newTargets = new ArrayList<>(edges.elements.get(index).targets());
newTargets.add(newEdge.target());
edges.elements.set(index, new Edge(sourceId, newTargets));
}
else {
edges.elements.add(newEdge);
}
return this;
}
/**
* Adds conditional edges to the graph based on the provided condition and mappings.
* @param sourceId the identifier of the source node
* @param condition the command action used to determine the target node
* @param mappings the mappings of conditions to target nodes
* @return this state graph instance
* @throws GraphStateException if the edge identifier is invalid, the mappings are
* empty, or the edge already exists
*/
public StateGraph addConditionalEdges(String sourceId, AsyncCommandAction condition, Map<String, String> mappings)
throws GraphStateException {
if (Objects.equals(sourceId, END)) {
throw Errors.invalidEdgeIdentifier.exception(END);
}
if (mappings == null || mappings.isEmpty()) {
throw Errors.edgeMappingIsEmpty.exception(sourceId);
}
var newEdge = new Edge(sourceId, new EdgeValue(new EdgeCondition(condition, mappings)));
if (edges.elements.contains(newEdge)) {
throw Errors.duplicateConditionalEdgeError.exception(sourceId);
}
else {
edges.elements.add(newEdge);
}
return this;
}
/**
* Adds conditional edges to the graph based on the provided edge action condition and
* mappings.
* @param sourceId the identifier of the source node
* @param condition the edge action used to determine the target node
* @param mappings the mappings of conditions to target nodes
* @return this state graph instance
* @throws GraphStateException if the edge identifier is invalid, the mappings are
* empty, or the edge already exists
*/
public StateGraph addConditionalEdges(String sourceId, AsyncEdgeAction condition, Map<String, String> mappings)
throws GraphStateException {
return addConditionalEdges(sourceId, AsyncCommandAction.of(condition), mappings);
}
/**
* Validates the structure of the graph ensuring all connections are valid.
* @throws GraphStateException if there are errors related to the graph state
*/
void validateGraph() throws GraphStateException {
var edgeStart = edges.edgeBySourceId(START).orElseThrow(Errors.missingEntryPoint::exception);
edgeStart.validate(nodes);
validateNode(nodes);
for (Edge edge : edges.elements) {
edge.validate(nodes);
}
}
private void validateNode(Nodes nodes) throws GraphStateException {
List<CommandNode> commandNodeList = nodes.elements.stream().filter(node -> {
return node instanceof CommandNode commandNode;
}).map(node -> (CommandNode) node).toList();
for (CommandNode commandNode : commandNodeList) {
for (String key : commandNode.getMappings().keySet()) {
if (!nodes.anyMatchById(key)) {
throw Errors.missingNodeInEdgeMapping.exception(commandNode.id(), key);
}
}
}
}
/**
* Compiles the state graph into a compiled graph using the provided configuration.
* @param config the compile configuration
* @return a compiled graph
* @throws GraphStateException if there are errors related to the graph state
*/
public CompiledGraph compile(CompileConfig config) throws GraphStateException {
Objects.requireNonNull(config, "config cannot be null");
validateGraph();
return new CompiledGraph(this, config);
}
/**
* Compiles the state graph into a compiled graph using a default configuration with
* memory saver.
* @return a compiled graph
* @throws GraphStateException if there are errors related to the graph state
*/
public CompiledGraph compile() throws GraphStateException {
SaverConfig saverConfig = SaverConfig.builder().register(SaverConstant.MEMORY, new MemorySaver()).build();
return compile(CompileConfig.builder().saverConfig(saverConfig).build());
}
/**
* Generates a drawable graph representation of the state graph.
* @param type the type of graph representation to generate
* @param title the title of the graph
* @param printConditionalEdges whether to include conditional edges in the output
* @return a diagram code of the state graph
*/
public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) {
String content = type.generator.generate(nodes, edges, title, printConditionalEdges);
return new GraphRepresentation(type, content);
}
/**
* Generates a drawable graph representation of the state graph with conditional edges
* included.
* @param type the type of graph representation to generate
* @param title the title of the graph
* @return a diagram code of the state graph
*/
public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {
String content = type.generator.generate(nodes, edges, title, true);
return new GraphRepresentation(type, content);
}
/**
* Generates a drawable graph representation of the state graph using the graph's name
* as title.
* @param type the type of graph representation to generate
* @return a diagram code of the state graph
*/
public GraphRepresentation getGraph(GraphRepresentation.Type type) {
String content = type.generator.generate(nodes, edges, name, true);
return new GraphRepresentation(type, content);
}
/**
* Container for nodes in the graph.
*/
public static class Nodes {
/**
* The collection of nodes.
*/
public final Set<Node> elements;
/**
* Instantiates a new Nodes container with the provided elements.
* @param elements the elements to initialize
*/
public Nodes(Collection<Node> elements) {
this.elements = new LinkedHashSet<>(elements);
}
/**
* Instantiates a new empty Nodes container.
*/
public Nodes() {
this.elements = new LinkedHashSet<>();
}
/**
* Checks if any node matches the given identifier.
* @param id the identifier to match
* @return true if a matching node is found, false otherwise
*/
public boolean anyMatchById(String id) {
return elements.stream().anyMatch(n -> Objects.equals(n.id(), id));
}
/**
* Returns a list of sub-state graph nodes.
* @return a list of sub-state graph nodes
*/
public List<SubStateGraphNode> onlySubStateGraphNodes() {
return elements.stream()
.filter(n -> n instanceof SubStateGraphNode)
.map(n -> (SubStateGraphNode) n)
.toList();
}
/**
* Returns a list of nodes excluding sub-state graph nodes.
* @return a list of nodes excluding sub-state graph nodes
*/
public List<Node> exceptSubStateGraphNodes() {
return elements.stream().filter(n -> !(n instanceof SubStateGraphNode)).toList();
}
}
/**
* Container for edges in the graph.
*/
public static class Edges {
/**
* The collection of edges.
*/
public final List<Edge> elements;
/**
* Instantiates a new Edges container with the provided elements.
* @param elements the elements to initialize
*/
public Edges(Collection<Edge> elements) {
this.elements = new LinkedList<>(elements);
}
/**
* Instantiates a new empty Edges container.
*/
public Edges() {
this.elements = new LinkedList<>();
}
/**
* Retrieves the first edge matching the specified source identifier.
* @param sourceId the source identifier to match
* @return an optional containing the matched edge, or empty if none found
*/
public Optional<Edge> edgeBySourceId(String sourceId) {
return elements.stream().filter(e -> Objects.equals(e.sourceId(), sourceId)).findFirst();
}
/**
* Retrieves a list of edges targeting the specified node identifier.
* @param targetId the target identifier to match
* @return a list of edges targeting the specified identifier
*/
public List<Edge> edgesByTargetId(String targetId) {
return elements.stream().filter(e -> e.anyMatchByTargetId(targetId)).toList();
}
}
}
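
The `addEdge` method above does not reject a second edge from the same source. It appends the new target to the existing edge, so one source can end up with several targets. The following is a minimal, self-contained sketch of that merge step. `EdgeMergeSketch` and `SimpleEdge` are hypothetical stand-ins rather than library classes, and the sketch matches edges by source id explicitly, which is an assumption about how `Edge` equality behaves in the library.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, simplified model of StateGraph.addEdge; not the library's classes.
class EdgeMergeSketch {

    // One edge per source node, holding every target registered for that source.
    record SimpleEdge(String sourceId, List<String> targets) {
    }

    private final List<SimpleEdge> edges = new ArrayList<>();

    // Adds a target; a repeated source id extends the existing edge instead of
    // creating a second one (assumed to mirror the indexOf/set branch in addEdge).
    EdgeMergeSketch addEdge(String sourceId, String targetId) {
        for (int i = 0; i < edges.size(); i++) {
            SimpleEdge existing = edges.get(i);
            if (existing.sourceId().equals(sourceId)) {
                List<String> merged = new ArrayList<>(existing.targets());
                merged.add(targetId);
                edges.set(i, new SimpleEdge(sourceId, List.copyOf(merged)));
                return this;
            }
        }
        edges.add(new SimpleEdge(sourceId, List.of(targetId)));
        return this;
    }

    // Returns a read-only snapshot of the edges registered so far.
    List<SimpleEdge> edges() {
        return List.copyOf(edges);
    }

    public static void main(String[] args) {
        EdgeMergeSketch graph = new EdgeMergeSketch()
            .addEdge("start", "a")
            .addEdge("start", "b")
            .addEdge("a", "end");
        System.out.println(graph.edges());
    }
}
```

Registering two targets for `start` yields a single edge with both targets. In the `CompiledGraph` constructor, an edge with more than one target is the case that gets wrapped in a `ParallelNode`.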
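CompileConfig
Configures how a graph is compiled, including checkpoint savers, interrupt points, and lifecycle listeners.
| Field | Type | Description |
| --- | --- | --- |
| saverConfig | SaverConfig | Saver configuration that manages checkpoint savers; replaces the old checkpointSaver field |
| lifecycleListeners | `Deque<GraphLifecycleListener>` | Queue of graph lifecycle listeners that observe node execution events |
| interruptsBefore | `Set<String>` | Ids of the nodes before which execution is interrupted |
| interruptsAfter | `Set<String>` | Ids of the nodes after which execution is interrupted |
| releaseThread | boolean | Thread release flag: whether the thread is released during execution |
| observationRegistry | ObservationRegistry | Observation registry used for monitoring and tracing |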
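Public methods
| Method | Description |
| --- | --- |
| releaseThread | Returns the current value of the thread release flag |
| lifecycleListeners | Returns the queue of graph lifecycle listeners; the Javadoc calls it unmodifiable, but the code returns the underlying deque directly |
| observationRegistry | Returns the observation registry used for monitoring and tracing |
| interruptsBefore | Returns the ids of the nodes before which execution is interrupted |
| interruptsAfter | Returns the ids of the nodes after which execution is interrupted |
| checkpointSaver | Retrieves the default checkpoint saver from the saver configuration; an overload that takes a type retrieves the saver registered for that type |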
package com.alibaba.cloud.ai.graph;
import com.alibaba.cloud.ai.graph.checkpoint.BaseCheckpointSaver;
import com.alibaba.cloud.ai.graph.checkpoint.config.SaverConfig;
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;
import io.micrometer.observation.ObservationRegistry;
import java.util.Collection;
import java.util.Deque;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.stream.Collectors;
import static com.alibaba.cloud.ai.graph.checkpoint.constant.SaverConstant.MEMORY;
import static java.util.Optional.ofNullable;
/**
* Configuration container that defines compile settings and behaviors. It
* includes various fields and methods to manage checkpoint savers and interrupts,
* providing both deprecated and current accessors.
*/
public class CompileConfig {
private SaverConfig saverConfig = new SaverConfig().register(MEMORY, new MemorySaver());
private Deque<GraphLifecycleListener> lifecycleListeners = new LinkedBlockingDeque<>(25);
// private BaseCheckpointSaver checkpointSaver; // replaced with SaverConfig
private Set<String> interruptsBefore = Set.of();
private Set<String> interruptsAfter = Set.of();
private boolean releaseThread = false;
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
/**
* Returns the current state of the thread release flag.
*
* @see BaseCheckpointSaver#release(RunnableConfig)
* @return true if the thread has been released, false otherwise
*/
public boolean releaseThread() {
return releaseThread;
}
/**
* Gets an unmodifiable list of node lifecycle listeners.
* @return The list of lifecycle listeners.
*/
public Queue<GraphLifecycleListener> lifecycleListeners() {
return lifecycleListeners;
}
/**
* Gets observation registry for monitoring and tracing.
* @return The observation registry instance.
*/
public ObservationRegistry observationRegistry() {
return observationRegistry;
}
/**
* Returns the array of interrupts that will occur before the specified node
* (deprecated).
* @return An array of interruptible nodes.
* @deprecated Use {@link #interruptsBefore()} instead for better immutability and
* type safety.
*/
@Deprecated
public String[] getInterruptBefore() {
return interruptsBefore.toArray(new String[0]);
}
/**
* Returns the array of interrupts that will occur after the specified node
* (deprecated).
* @return An array of interruptible nodes.
* @deprecated Use {@link #interruptsAfter()} instead for better immutability and type
* safety.
*/
@Deprecated
public String[] getInterruptAfter() {
return interruptsAfter.toArray(new String[0]);
}
/**
* Returns the set of interrupts that will occur before the specified node.
* @return An unmodifiable set of interruptible nodes.
*/
public Set<String> interruptsBefore() {
return interruptsBefore;
}
/**
* Returns the set of interrupts that will occur after the specified node.
* @return An unmodifiable set of interruptible nodes.
*/
public Set<String> interruptsAfter() {
return interruptsAfter;
}
/**
* Retrieves a checkpoint saver based on the specified type from the saver
* configuration.
* @param type The type of the checkpoint saver to retrieve.
* @return An Optional containing the checkpoint saver if available; otherwise, empty.
*/
public Optional<BaseCheckpointSaver> checkpointSaver(String type) {
return ofNullable(saverConfig.get(type));
}
/**
* Retrieves the default checkpoint saver from the saver configuration.
* @return An Optional containing the default checkpoint saver if available;
* otherwise, empty.
*/
public Optional<BaseCheckpointSaver> checkpointSaver() {
return ofNullable(saverConfig.get());
}
/**
* Returns a new instance of the builder with default configuration settings.
* @return A new Builder instance.
*/
public static Builder builder() {
return new Builder(new CompileConfig());
}
/**
* Returns a new instance of the builder initialized with the provided configuration.
* @param config The compile configuration to use as a base.
* @return A new Builder instance initialized with the given configuration.
*/
public static Builder builder(CompileConfig config) {
return new Builder(config);
}
/**
* Builder class for creating instances of CompileConfig. It allows setting various
* options such as savers, interrupts, and lifecycle listeners in a fluent manner.
*/
public static class Builder {
private final CompileConfig config;
/**
* Initializes the builder with the provided compile configuration.
* @param config The base configuration to start from.
*/
protected Builder(CompileConfig config) {
this.config = new CompileConfig(config);
}
/**
* Sets whether the thread should be released during execution.
* @param releaseThread Flag indicating whether to release the thread.
* @see BaseCheckpointSaver#release(RunnableConfig)
* @return This builder instance for method chaining.
*/
public Builder releaseThread(boolean releaseThread) {
this.config.releaseThread = releaseThread;
return this;
}
/**
* Sets the observation registry for monitoring and tracing.
* @param observationRegistry The ObservationRegistry to use.
* @return This builder instance for method chaining.
*/
public Builder observationRegistry(ObservationRegistry observationRegistry) {
this.config.observationRegistry = observationRegistry;
return this;
}
/**
* Sets the saver configuration for checkpoints.
* @param saverConfig The SaverConfig to use.
* @return This builder instance for method chaining.
*/
public Builder saverConfig(SaverConfig saverConfig) {
this.config.saverConfig = saverConfig;
return this;
}
/**
* Sets individual interrupt points that trigger before node execution using
* varargs.
* @param interruptBefore One or more strings representing interrupt points.
* @return This builder instance for method chaining.
*/
public Builder interruptBefore(String... interruptBefore) {
this.config.interruptsBefore = Set.of(interruptBefore);
return this;
}
/**
* Sets individual interrupt points that trigger after node execution using
* varargs.
* @param interruptAfter One or more strings representing interrupt points.
* @return This builder instance for method chaining.
*/
public Builder interruptAfter(String... interruptAfter) {
this.config.interruptsAfter = Set.of(interruptAfter);
return this;
}
/**
* Sets multiple interrupt points that trigger before node execution from a
* collection.
* @param interruptsBefore Collection of strings representing interrupt points.
* @return This builder instance for method chaining.
*/
public Builder interruptsBefore(Collection<String> interruptsBefore) {
this.config.interruptsBefore = interruptsBefore.stream().collect(Collectors.toUnmodifiableSet());
return this;
}
/**
* Sets multiple interrupt points that trigger after node execution from a
* collection.
* @param interruptsAfter Collection of strings representing interrupt points.
* @return This builder instance for method chaining.
*/
public Builder interruptsAfter(Collection<String> interruptsAfter) {
this.config.interruptsAfter = interruptsAfter.stream().collect(Collectors.toUnmodifiableSet());
return this;
}
/**
* Adds a lifecycle listener to monitor node execution events.
* @param listener The GraphLifecycleListener to add.
* @return This builder instance for method chaining.
*/
public Builder withLifecycleListener(GraphLifecycleListener listener) {
this.config.lifecycleListeners.offer(listener);
return this;
}
/**
* Finalizes the configuration and returns the compiled instance.
* @return The configured CompileConfig object.
*/
public CompileConfig build() {
return config;
}
}
/**
* Default constructor used internally to create a new configuration with default
* settings. Made private to ensure all instances are created through the builder
* pattern.
*/
private CompileConfig() {
}
/**
* Copy constructor to create a new instance based on an existing configuration.
* @param config The configuration to copy.
*/
private CompileConfig(CompileConfig config) {
this.saverConfig = config.saverConfig;
this.interruptsBefore = config.interruptsBefore;
this.interruptsAfter = config.interruptsAfter;
this.releaseThread = config.releaseThread;
this.lifecycleListeners = config.lifecycleListeners;
this.observationRegistry = config.observationRegistry;
}
}
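
Two details of the `CompileConfig` listing above depend only on standard JDK behavior and are easy to miss. The listener queue is a `LinkedBlockingDeque` bounded at 25, so `offer` returns false instead of throwing once the queue is full. The varargs setters `interruptBefore` and `interruptAfter` use `Set.of`, which rejects repeated elements, whereas the collection-based setters collect into an unmodifiable set and keep a single copy of each id. The sketch below demonstrates those JDK behaviors in isolation; `CompileConfigJdkBehaviors` and its method names are hypothetical, and the code does not call the library.

```java
import java.util.List;
import java.util.Set;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.stream.Collectors;

// Hypothetical demo of plain JDK behavior; it does not use CompileConfig itself.
class CompileConfigJdkBehaviors {

    // Offers `count` listeners to a deque bounded at `capacity` and returns how
    // many were accepted.
    static int acceptedListeners(int capacity, int count) {
        LinkedBlockingDeque<String> listeners = new LinkedBlockingDeque<>(capacity);
        int accepted = 0;
        for (int i = 0; i < count; i++) {
            // offer() returns false when the deque is full instead of throwing.
            if (listeners.offer("listener-" + i)) {
                accepted++;
            }
        }
        return accepted;
    }

    // Set.of(E...) throws IllegalArgumentException when an element is repeated.
    static boolean varargsRejectsDuplicates(String... nodeIds) {
        try {
            Set.of(nodeIds);
            return false;
        }
        catch (IllegalArgumentException e) {
            return true;
        }
    }

    // Collectors.toUnmodifiableSet() keeps one copy of each element.
    static Set<String> collectionDropsDuplicates(List<String> nodeIds) {
        return nodeIds.stream().collect(Collectors.toUnmodifiableSet());
    }

    public static void main(String[] args) {
        System.out.println(acceptedListeners(25, 30));
        System.out.println(varargsRejectsDuplicates("review", "review"));
        System.out.println(collectionDropsDuplicates(List.of("review", "review", "publish")));
    }
}
```

One practical reading, offered as a suggestion rather than library guidance: pass a collection instead of varargs when the interrupt list may contain repeated ids, and keep in mind that listeners offered beyond the queue capacity are not registered.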
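CompiledGraph
The core component of the graph execution framework; it represents a compiled graph structure.
| Field | Type | Description |
| --- | --- | --- |
| stateGraph | StateGraph | The original state graph this instance was compiled from |
| keyStrategyMap | `Map<String, KeyStrategy>` | Update strategy for each state key |
Public methods
| Category | Method | Description |
| --- | --- | --- |
| Graph execution | stream | - No arguments: creates the default streaming output - (Map |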
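StreamMode enum
- VALUES: value stream mode
- SNAPSHOTS: snapshot stream mode
AsyncNodeGenerator inner class: handles asynchronous execution of the graph and processing of streamed output
Core fields:
- Cursor cursor: cursor object that tracks the current and next node ids
  - String currentNodeId: id of the current node
  - String nextNodeId: id of the next node
  - String resumeFrom: id of the node to resume execution from
- int iteration: current iteration count
- RunnableConfig config: runtime configuration
- boolean returnFromEmbed: marks whether control is returning from an embedded generator
- Map<String, Object> currentState: the current state as a Map
- OverAllState overAllState: the wrapped OverAllState object, which offers richer state operations
| Core method | Description |
| --- | --- |
| next | Executes the next step of the graph |
| evaluateAction | Evaluates and runs a node action |
| nextNodeId | Determines the next node from the current node and state |
| getEmbedGenerator | Extracts an embedded generator from a partial state |
| processGeneratorOutput | Processes the data emitted by a generator |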
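
The `nextNodeId` step in the table above can be pictured as a lookup with two cases: an edge either names its target directly, or carries a condition whose result is translated into a node id through a mappings table. The following is a minimal, synchronous sketch under that reading. `RoutingSketch`, `Route`, and `resolve` are hypothetical names, the exceptions are plain JDK ones, and the state update that the real method merges is left out.

```java
import java.util.Map;
import java.util.function.Function;

// Hypothetical, simplified model of edge routing; not the library's classes.
class RoutingSketch {

    // Either targetId is set (fixed edge) or condition and mappings are set
    // (conditional edge).
    record Route(String targetId, Function<Map<String, Object>, String> condition,
            Map<String, String> mappings) {

        static Route fixed(String targetId) {
            return new Route(targetId, null, null);
        }

        static Route conditional(Function<Map<String, Object>, String> condition,
                Map<String, String> mappings) {
            return new Route(null, condition, mappings);
        }
    }

    // Resolves the node to run after nodeId, given the edge table and the state.
    static String resolve(Map<String, Route> edges, String nodeId, Map<String, Object> state) {
        Route route = edges.get(nodeId);
        if (route == null) {
            throw new IllegalStateException("missing edge for node: " + nodeId);
        }
        if (route.targetId() != null) {
            return route.targetId();
        }
        // The condition returns a label; the mappings translate it to a node id.
        String label = route.condition().apply(state);
        String target = route.mappings().get(label);
        if (target == null) {
            throw new IllegalStateException("no mapping for '" + label + "' on node: " + nodeId);
        }
        return target;
    }

    public static void main(String[] args) {
        Map<String, Route> edges = Map.of(
            "draft", Route.fixed("review"),
            "review", Route.conditional(
                state -> Boolean.TRUE.equals(state.get("approved")) ? "yes" : "no",
                Map.of("yes", "publish", "no", "draft")));
        System.out.println(resolve(edges, "draft", Map.of()));
        System.out.println(resolve(edges, "review", Map.of("approved", true)));
    }
}
```

In the library listing, the same decision is made asynchronously and the result is returned as a `Command` that carries both the target node and the updated state.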
package com.alibaba.cloud.ai.graph;
import com.alibaba.cloud.ai.graph.action.AsyncCommandAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import com.alibaba.cloud.ai.graph.action.Command;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.alibaba.cloud.ai.graph.checkpoint.BaseCheckpointSaver;
import com.alibaba.cloud.ai.graph.checkpoint.Checkpoint;
import com.alibaba.cloud.ai.graph.exception.Errors;
import com.alibaba.cloud.ai.graph.exception.GraphRunnerException;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.exception.RunnableErrors;
import com.alibaba.cloud.ai.graph.internal.edge.Edge;
import com.alibaba.cloud.ai.graph.internal.edge.EdgeValue;
import com.alibaba.cloud.ai.graph.internal.node.CommandNode;
import com.alibaba.cloud.ai.graph.internal.node.ParallelNode;
import com.alibaba.cloud.ai.graph.state.StateSnapshot;
import com.alibaba.cloud.ai.graph.streaming.AsyncGeneratorUtils;
import com.alibaba.cloud.ai.graph.utils.LifeListenerUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static com.alibaba.cloud.ai.graph.StateGraph.*;
import static java.lang.String.format;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.failedFuture;
import static java.util.stream.Collectors.toList;
/**
* The type Compiled graph.
*/
public class CompiledGraph {
private static final Logger log = LoggerFactory.getLogger(CompiledGraph.class);
/**
* The enum Stream mode.
*/
public enum StreamMode {
/**
* Values stream mode.
*/
VALUES,
/**
* Snapshots stream mode.
*/
SNAPSHOTS
}
/**
* The State graph.
*/
public final StateGraph stateGraph;
private final Map<String, KeyStrategy> keyStrategyMap;
/**
* The Nodes.
*/
final Map<String, AsyncNodeActionWithConfig> nodes = new LinkedHashMap<>();
/**
* The Edges.
*/
final Map<String, EdgeValue> edges = new LinkedHashMap<>();
private final ProcessedNodesEdgesAndConfig processedData;
private int maxIterations = 25;
/**
* The Compile config.
*/
public final CompileConfig compileConfig;
private static String INTERRUPTAFTER = "INTERRUPTED";
/**
* Constructs a CompiledGraph with the given StateGraph.
* @param stateGraph the StateGraph to be used in this CompiledGraph
* @param compileConfig the compile config
* @throws GraphStateException the graph state exception
*/
protected CompiledGraph(StateGraph stateGraph, CompileConfig compileConfig) throws GraphStateException {
this.stateGraph = stateGraph;
this.keyStrategyMap = Objects.isNull(stateGraph.getOverAllStateFactory())
? stateGraph.getKeyStrategyFactory()
.apply()
.entrySet()
.stream()
.map(e -> Map.entry(e.getKey(), e.getValue()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))
: stateGraph.getOverAllStateFactory().create().keyStrategies();
this.processedData = ProcessedNodesEdgesAndConfig.process(stateGraph, compileConfig);
// CHECK INTERRUPTIONS
for (String interruption : processedData.interruptsBefore()) {
if (!processedData.nodes().anyMatchById(interruption)) {
throw Errors.interruptionNodeNotExist.exception(interruption);
}
}
for (String interruption : processedData.interruptsAfter()) {
if (!processedData.nodes().anyMatchById(interruption)) {
throw Errors.interruptionNodeNotExist.exception(interruption);
}
}
// RE-CREATE THE EVENTUALLY UPDATED COMPILE CONFIG
this.compileConfig = CompileConfig.builder(compileConfig)
.interruptsBefore(processedData.interruptsBefore())
.interruptsAfter(processedData.interruptsAfter())
.build();
// EVALUATES NODES
for (var n : processedData.nodes().elements) {
var factory = n.actionFactory();
Objects.requireNonNull(factory, format("action factory for node id '%s' is null!", n.id()));
nodes.put(n.id(), factory.apply(compileConfig));
}
// EVALUATE EDGES
for (var e : processedData.edges().elements) {
var targets = e.targets();
if (targets.size() == 1) {
edges.put(e.sourceId(), targets.get(0));
}
else {
Supplier<Stream<EdgeValue>> parallelNodeStream = () -> targets.stream()
.filter(target -> nodes.containsKey(target.id()));
var parallelNodeEdges = parallelNodeStream.get()
.map(target -> new Edge(target.id()))
.filter(ee -> processedData.edges().elements.contains(ee))
.map(ee -> processedData.edges().elements.indexOf(ee))
.map(index -> processedData.edges().elements.get(index))
.toList();
var parallelNodeTargets = parallelNodeEdges.stream()
.map(ee -> ee.target().id())
.collect(Collectors.toSet());
if (parallelNodeTargets.size() > 1) {
var conditionalEdges = parallelNodeEdges.stream()
.filter(ee -> ee.target().value() != null)
.toList();
if (!conditionalEdges.isEmpty()) {
throw Errors.unsupportedConditionalEdgeOnParallelNode.exception(e.sourceId(),
conditionalEdges.stream().map(Edge::sourceId).toList());
}
throw Errors.illegalMultipleTargetsOnParallelNode.exception(e.sourceId(), parallelNodeTargets);
}
var actions = parallelNodeStream.get()
// .map( target -> nodes.remove(target.id()) )
.map(target -> nodes.get(target.id()))
.toList();
var parallelNode = new ParallelNode(e.sourceId(), actions, keyStrategyMap, compileConfig);
nodes.put(parallelNode.id(), parallelNode.actionFactory().apply(compileConfig));
edges.put(e.sourceId(), new EdgeValue(parallelNode.id()));
edges.put(parallelNode.id(), new EdgeValue(parallelNodeTargets.iterator().next()));
}
}
}
public Collection<StateSnapshot> getStateHistory(RunnableConfig config) {
BaseCheckpointSaver saver = compileConfig.checkpointSaver()
.orElseThrow(() -> (new IllegalStateException("Missing CheckpointSaver!")));
return saver.list(config)
.stream()
.map(checkpoint -> StateSnapshot.of(keyStrategyMap, checkpoint, config, stateGraph.getStateFactory()))
.collect(toList());
}
/**
* Same as {@link #stateOf(RunnableConfig)} but throws an IllegalStateException if
* checkpoint is not found.
* @param config the RunnableConfig
* @return the StateSnapshot of the given RunnableConfig
* @throws IllegalStateException if the saver is not defined, or no checkpoint is
* found
*/
public StateSnapshot getState(RunnableConfig config) {
return stateOf(config).orElseThrow(() -> (new IllegalStateException("Missing Checkpoint!")));
}
/**
* Get the StateSnapshot of the given RunnableConfig.
* @param config the RunnableConfig
* @return an Optional of StateSnapshot of the given RunnableConfig
* @throws IllegalStateException if the saver is not defined
*/
public Optional<StateSnapshot> stateOf(RunnableConfig config) {
BaseCheckpointSaver saver = compileConfig.checkpointSaver()
.orElseThrow(() -> (new IllegalStateException("Missing CheckpointSaver!")));
return saver.get(config)
.map(checkpoint -> StateSnapshot.of(keyStrategyMap, checkpoint, config, stateGraph.getStateFactory()));
}
/**
* Update the state of the graph with the given values. If asNode is given, it will be
* used to determine the next node to run. If not given, the next node will be
* determined by the state graph.
* @param config the RunnableConfig containing the graph state
* @param values the values to be updated
* @param asNode the node id to be used for the next node. can be null
* @return the updated RunnableConfig
* @throws Exception when something goes wrong
*/
public RunnableConfig updateState(RunnableConfig config, Map<String, Object> values, String asNode)
throws Exception {
BaseCheckpointSaver saver = compileConfig.checkpointSaver()
.orElseThrow(() -> (new IllegalStateException("Missing CheckpointSaver!")));
// merge values with checkpoint values
Checkpoint branchCheckpoint = saver.get(config)
.map(Checkpoint::new)
.map(cp -> cp.updateState(values, keyStrategyMap))
.orElseThrow(() -> (new IllegalStateException("Missing Checkpoint!")));
String nextNodeId = null;
if (asNode != null) {
var nextNodeCommand = nextNodeId(asNode, branchCheckpoint.getState(), config);
nextNodeId = nextNodeCommand.gotoNode();
branchCheckpoint = branchCheckpoint.updateState(nextNodeCommand.update(), keyStrategyMap);
}
// update checkpoint in saver
RunnableConfig newConfig = saver.put(config, branchCheckpoint);
return RunnableConfig.builder(newConfig).checkPointId(branchCheckpoint.getId()).nextNode(nextNodeId).build();
}
/**
* Update the state of the graph with the given values.
* @param config the RunnableConfig containing the graph state
* @param values the values to be updated
* @return the updated RunnableConfig
* @throws Exception when something goes wrong
*/
public RunnableConfig updateState(RunnableConfig config, Map<String, Object> values) throws Exception {
return updateState(config, values, null);
}
/**
* Sets the maximum number of iterations for the graph execution.
* @param maxIterations the maximum number of iterations
* @throws IllegalArgumentException if maxIterations is less than or equal to 0
*/
public void setMaxIterations(int maxIterations) {
if (maxIterations <= 0) {
throw new IllegalArgumentException("maxIterations must be > 0!");
}
this.maxIterations = maxIterations;
}
private Command nextNodeId(EdgeValue route, Map<String, Object> state, String nodeId, RunnableConfig config)
throws Exception {
if (route == null) {
throw RunnableErrors.missingEdge.exception(nodeId);
}
if (route.id() != null) {
return new Command(route.id(), state);
}
if (route.value() != null) {
OverAllState derefState = stateGraph.getStateFactory().apply(state);
var command = route.value().action().apply(derefState, config).get();
var newRoute = command.gotoNode();
String result = route.value().mappings().get(newRoute);
if (result == null) {
throw RunnableErrors.missingNodeInEdgeMapping.exception(nodeId, newRoute);
}
var currentState = OverAllState.updateState(state, command.update(), keyStrategyMap);
return new Command(result, currentState);
}
throw RunnableErrors.executionError.exception(format("invalid edge value for nodeId: [%s] !", nodeId));
}
/**
* Determines the next node ID based on the current node ID and state.
* @param nodeId the current node ID
* @param state the current state
* @return the next node command
* @throws Exception if there is an error determining the next node ID
*/
private Command nextNodeId(String nodeId, Map<String, Object> state, RunnableConfig config) throws Exception {
return nextNodeId(edges.get(nodeId), state, nodeId, config);
}
private Command getEntryPoint(Map<String, Object> state, RunnableConfig config) throws Exception {
var entryPoint = this.edges.get(START);
return nextNodeId(entryPoint, state, "entryPoint", config);
}
private boolean shouldInterruptBefore(String nodeId, String previousNodeId) {
if (previousNodeId == null) { // FIX RESUME ERROR
return false;
}
return compileConfig.interruptsBefore().contains(nodeId);
}
private boolean shouldInterruptAfter(String nodeId, String previousNodeId) {
if (nodeId == null || Objects.equals(nodeId, previousNodeId)) { // FIX RESUME ERROR
return false;
}
return (compileConfig.interruptBeforeEdge() && Objects.equals(nodeId, INTERRUPTAFTER))
|| compileConfig.interruptsAfter().contains(nodeId);
}
private Optional<Checkpoint> addCheckpoint(RunnableConfig config, String nodeId, Map<String, Object> state,
String nextNodeId) throws Exception {
if (compileConfig.checkpointSaver().isPresent()) {
var cp = Checkpoint.builder().nodeId(nodeId).state(cloneState(state)).nextNodeId(nextNodeId).build();
compileConfig.checkpointSaver().get().put(config, cp);
return Optional.of(cp);
}
return Optional.empty();
}
/**
* Gets initial state.
* @param inputs the inputs
* @param config the config
* @return the initial state
*/
Map<String, Object> getInitialState(Map<String, Object> inputs, RunnableConfig config) {
return compileConfig.checkpointSaver()
.flatMap(saver -> saver.get(config))
.map(cp -> OverAllState.updateState(cp.getState(), inputs, keyStrategyMap))
.orElseGet(() -> OverAllState.updateState(new HashMap<>(), inputs, keyStrategyMap));
}
/**
* Clones the given state data through the state serializer of the graph.
* @param data the state data
* @return the cloned state
*/
OverAllState cloneState(Map<String, Object> data) throws IOException, ClassNotFoundException {
return stateGraph.getStateSerializer().cloneObject(data);
}
/**
* Creates an AsyncGenerator stream of NodeOutput based on the provided inputs.
* @param inputs the input map
* @param config the invoke configuration
* @return an AsyncGenerator stream of NodeOutput
*/
public AsyncGenerator<NodeOutput> stream(Map<String, Object> inputs, RunnableConfig config)
throws GraphRunnerException {
Objects.requireNonNull(config, "config cannot be null");
final AsyncNodeGenerator<NodeOutput> generator = new AsyncNodeGenerator<>(stateCreate(inputs), config);
return new AsyncGenerator.WithEmbed<>(generator);
}
/**
* Streams node outputs starting from an already constructed state.
* @param overAllState the initial state
* @param config the invoke configuration
* @return an AsyncGenerator stream of NodeOutput
*/
public AsyncGenerator<NodeOutput> streamFromInitialNode(OverAllState overAllState, RunnableConfig config)
throws GraphRunnerException {
Objects.requireNonNull(config, "config cannot be null");
final AsyncNodeGenerator<NodeOutput> generator = new AsyncNodeGenerator<>(overAllState, config);
return new AsyncGenerator.WithEmbed<>(generator);
}
/**
* Creates an AsyncGenerator stream of NodeOutput based on the provided inputs.
* @param inputs the input map
* @return an AsyncGenerator stream of NodeOutput
*/
public AsyncGenerator<NodeOutput> stream(Map<String, Object> inputs) throws GraphRunnerException {
return this.streamFromInitialNode(stateCreate(inputs), RunnableConfig.builder().build());
}
/**
* Stream async generator.
* @return the async generator
*/
public AsyncGenerator<NodeOutput> stream() throws GraphRunnerException {
return this.stream(Map.of(), RunnableConfig.builder().build());
}
/**
* Invokes the graph execution with the provided inputs and returns the final state.
* @param inputs the input map
* @param config the invoke configuration
* @return an Optional containing the final state if present, otherwise an empty
* Optional
*/
public Optional<OverAllState> invoke(Map<String, Object> inputs, RunnableConfig config)
throws GraphRunnerException {
return stream(inputs, config).stream().reduce((a, b) -> b).map(NodeOutput::state);
}
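/**
* Invokes the graph execution with an existing state and returns the final state.
* @param overAllState the initial state
* @param config the invoke configuration
* @return an Optional containing the final state if present, otherwise an empty
* Optional
*/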
public Optional<OverAllState> invoke(OverAllState overAllState, RunnableConfig config) throws GraphRunnerException {
return streamFromInitialNode(overAllState, config).stream().reduce((a, b) -> b).map(NodeOutput::state);
}
/**
* Invokes the graph execution with the provided inputs and returns the final state.
* @param inputs the input map
* @return an Optional containing the final state if present, otherwise an empty
* Optional
*/
public Optional<OverAllState> invoke(Map<String, Object> inputs) throws GraphRunnerException {
return this.invoke(stateCreate(inputs), RunnableConfig.builder().build());
}
private OverAllState stateCreate(Map<String, Object> inputs) {
// Creates a new OverAllState instance based on the presence of an
// OverAllStateFactory in the stateGraph.
// If no factory is present, constructs a new state using key strategies from
// the graph and provided input data.
// If a factory exists, uses it to create the state and applies the input data.
return Objects.isNull(stateGraph.getOverAllStateFactory()) ? OverAllStateBuilder.builder()
.withKeyStrategies(stateGraph.getKeyStrategyFactory().apply())
.withData(inputs)
.build() : stateGraph.getOverAllStateFactory().create().input(inputs);
}
/**
* Experimental API
* @param feedback the feedback
* @param config the config
* @return the optional
*/
public Optional<OverAllState> resume(OverAllState.HumanFeedback feedback, RunnableConfig config)
throws GraphRunnerException {
StateSnapshot stateSnapshot = this.getState(config);
OverAllState resumeState = stateCreate(stateSnapshot.state().data());
resumeState.withResume();
resumeState.withHumanFeedback(feedback);
return this.invoke(resumeState, config);
}
/**
* Creates an AsyncGenerator stream of NodeOutput with the stream mode set to
* {@code StreamMode.SNAPSHOTS}.
* @param inputs the input map
* @param config the invoke configuration
* @return an AsyncGenerator stream of NodeOutput
*/
public AsyncGenerator<NodeOutput> streamSnapshots(Map<String, Object> inputs, RunnableConfig config)
throws GraphRunnerException {
Objects.requireNonNull(config, "config cannot be null");
final AsyncNodeGenerator<NodeOutput> generator = new AsyncNodeGenerator<>(stateCreate(inputs),
config.withStreamMode(StreamMode.SNAPSHOTS));
return new AsyncGenerator.WithEmbed<>(generator);
}
/**
* Generates a drawable graph representation of the state graph.
* @param type the type of graph representation to generate
* @param title the title of the graph
* @param printConditionalEdges whether to print conditional edges
* @return a diagram code of the state graph
*/
public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) {
String content = type.generator.generate(processedData.nodes(), processedData.edges(), title,
printConditionalEdges);
return new GraphRepresentation(type, content);
}
/**
* Get the last StateSnapshot of the given RunnableConfig.
* @param config - the RunnableConfig
* @return the last StateSnapshot of the given RunnableConfig if any
*/
Optional<StateSnapshot> lastStateOf(RunnableConfig config) {
return getStateHistory(config).stream().findFirst();
}
/**
* Generates a drawable graph representation of the state graph.
* @param type the type of graph representation to generate
* @param title the title of the graph
* @return a diagram code of the state graph
*/
public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {
String content = type.generator.generate(processedData.nodes(), processedData.edges(), title, true);
return new GraphRepresentation(type, content);
}
/**
* Generates a drawable graph representation of the state graph with default title.
* @param type the type of graph representation to generate
* @return a diagram code of the state graph
*/
public GraphRepresentation getGraph(GraphRepresentation.Type type) {
return getGraph(type, "Graph Diagram", true);
}
/**
* Async Generator for streaming outputs.
*
* @param <Output> the type of the output
*/
public class AsyncNodeGenerator<Output extends NodeOutput> implements AsyncGenerator<Output> {
final Cursor cursor;
// String currentNodeId;
// String nextNodeId;
int iteration = 0;
final RunnableConfig config;
volatile boolean returnFromEmbed = false;
Map<String, Object> currentState;
/**
* The Over all state.
*/
OverAllState overAllState;
static class Cursor {
private String currentNodeId;
private String nextNodeId;
private String resumeFrom;
Cursor() {
currentNodeId = START;
nextNodeId = null;
resumeFrom = null;
}
Cursor(Checkpoint cp) {
currentNodeId = null;
nextNodeId = cp.getNextNodeId();
resumeFrom = cp.getNodeId();
}
void reset() {
currentNodeId = null;
nextNodeId = null;
resumeFrom = null;
}
boolean isResumed() {
return resumeFrom != null;
}
String nextNodeId() {
return nextNodeId;
}
void setNextNodeId(String value) {
nextNodeId = value;
}
String currentNodeId() {
return currentNodeId;
}
void setCurrentNodeId(String value) {
currentNodeId = value;
}
String resumeFrom() {
return resumeFrom;
}
void setResumeFrom(String value) {
resumeFrom = value;
}
}
/**
* Instantiates a new Async node generator.
* @param overAllState the over all state
* @param config the config
*/
protected AsyncNodeGenerator(OverAllState overAllState, RunnableConfig config) throws GraphRunnerException {
if (overAllState.isResume()) {
log.trace("RESUME REQUEST");
BaseCheckpointSaver saver = compileConfig.checkpointSaver()
.orElseThrow(() -> (new IllegalStateException(
"inputs cannot be null (ie. resume request) if no checkpoint saver is configured")));
Checkpoint startCheckpoint = saver.get(config)
.orElseThrow(() -> (new IllegalStateException("Resume request without a saved checkpoint!")));
this.currentState = startCheckpoint.getState();
this.cursor = new Cursor(startCheckpoint);
this.config = config.withCheckPointId(null);
this.overAllState = overAllState.input(this.currentState);
log.trace("RESUME FROM {}", startCheckpoint.getNodeId());
// this.nextNodeId = startCheckpoint.getNextNodeId();
// this.currentNodeId = null;
// log.trace("RESUME FROM {}", startCheckpoint.getNodeId());
}
else {
log.trace("START");
Map<String, Object> inputs = overAllState.data();
boolean verify = overAllState.keyVerify();
if (!CollectionUtils.isEmpty(inputs) && !verify) {
throw RunnableErrors.initializationError.exception(Arrays.toString(inputs.keySet().toArray()));
}
// patch for backward support of AppendableValue
this.currentState = getInitialState(inputs, config);
this.overAllState = overAllState.input(currentState);
this.cursor = new Cursor();
// this.nextNodeId = null;
// this.currentNodeId = START;
this.config = config;
}
}
private Optional<BaseCheckpointSaver.Tag> releaseThread() throws Exception {
if (compileConfig.releaseThread() && compileConfig.checkpointSaver().isPresent()) {
return Optional.of(compileConfig.checkpointSaver().get().release(config));
}
return Optional.empty();
}
/**
* Builds the output for the given node from a clone of the current state.
* @param nodeId the node id
* @return the node output
*/
@SuppressWarnings("unchecked")
protected Output buildNodeOutput(String nodeId) {
return (Output) NodeOutput.of(nodeId, cloneState(currentState));
}
/**
* Wraps the given data in a new OverAllState that reuses the key strategies and
* the resume flag of the current state.
* @param data the state data
* @return the new OverAllState
*/
OverAllState cloneState(Map<String, Object> data) {
return new OverAllState(data, keyStrategyMap, overAllState.isResume());
}
/**
* Build state snapshot output.
* @param checkpoint the checkpoint
* @return the output
*/
@SuppressWarnings("unchecked")
protected Output buildStateSnapshot(Checkpoint checkpoint) {
return (Output) StateSnapshot.of(keyStrategyMap, checkpoint, config,
stateGraph.getStateSerializer().stateFactory());
}
/**
* Gets embed generator from partial state.
* @param partialState the partial state containing generator instances
* @return an Optional containing Data with the generator if found, empty
* otherwise
*/
private Optional<Data<Output>> getEmbedGenerator(Map<String, Object> partialState) {
// Extract all AsyncGenerator instances
List<AsyncGenerator<Output>> asyncNodeGenerators = new ArrayList<>();
var generatorEntries = partialState.entrySet().stream().filter(e -> {
// Handles parallel nodes that return async generators under the same key
Object value = e.getValue();
if (value instanceof AsyncGenerator) {
return true;
}
if (value instanceof Collection collection) {
collection.forEach(o -> {
if (o instanceof AsyncGenerator<?>) {
asyncNodeGenerators.add((AsyncGenerator<Output>) o);
}
});
}
return false;
}).collect(Collectors.toList());
if (generatorEntries.isEmpty() && asyncNodeGenerators.isEmpty()) {
return Optional.empty();
}
// Log information about found generators
if (generatorEntries.size() > 1) {
log.debug("Multiple generators found: {} - keys: {}", generatorEntries.size(),
generatorEntries.stream().map(Map.Entry::getKey).collect(Collectors.joining(", ")));
}
// Create appropriate generator (single or merged)
AsyncGenerator<Output> generator = AsyncGeneratorUtils.createAppropriateGenerator(generatorEntries,
asyncNodeGenerators, keyStrategyMap);
// Create data processing logic for the generator
return Optional.of(Data.composeWith(generator.map(n -> {
n.setSubGraph(true);
return n;
}), data -> processGeneratorOutput(data, partialState, generatorEntries)));
}
/**
* Processes output data from generator.
* @param data output data from generator
* @param partialState partial state
* @param generatorEntries generator entries list
* @throws Exception if an error occurs during processing
*/
@SuppressWarnings("unchecked")
private void processGeneratorOutput(Object data, Map<String, Object> partialState,
List<Map.Entry<String, Object>> generatorEntries) throws Exception {
// Remove all generators
Map<String, Object> partialStateWithoutGenerators = new HashMap<>();
for (Map.Entry<String, Object> entry : partialState.entrySet()) {
if (entry.getValue() instanceof AsyncGenerator) {
continue; // Skip top-level AsyncGenerator values
}
if (entry.getValue() instanceof Collection<?>) {
Collection<?> collection = (Collection<?>) entry.getValue();
ArrayList<Object> filteredCollection = new ArrayList<>();
for (Object item : collection) {
if (!(item instanceof AsyncGenerator)) {
filteredCollection.add(item);
}
}
if (!filteredCollection.isEmpty()) {
partialStateWithoutGenerators.put(entry.getKey(), filteredCollection);
}
}
else {
// Keep the entry if it's not an AsyncGenerator and not a collection
// containing it
partialStateWithoutGenerators.put(entry.getKey(), entry.getValue());
}
}
// Update state with partial state without generators
var intermediateState = OverAllState.updateState(currentState, partialStateWithoutGenerators,
keyStrategyMap);
currentState = intermediateState;
overAllState.updateState(partialStateWithoutGenerators);
// If data is not null and is a Map, update state with it
if (data != null) {
if (data instanceof Map<?, ?>) {
currentState = OverAllState.updateState(intermediateState, (Map<String, Object>) data,
keyStrategyMap);
overAllState.updateState((Map<String, Object>) data);
if (log.isDebugEnabled() && generatorEntries.size() > 1) {
log.debug("Updated state with data keys: {}",
((Map<String, Object>) data).keySet().stream().collect(Collectors.joining(", ")));
}
}
else {
throw new IllegalArgumentException("Embedded generator must return a Map");
}
}
// Get next node command
var nextNodeCommand = nextNodeId(cursor.currentNodeId(), currentState, config);
cursor.setNextNodeId(nextNodeCommand.gotoNode());
currentState = nextNodeCommand.update();
returnFromEmbed = true;
}
private CompletableFuture<Data<Output>> evaluateAction(AsyncNodeActionWithConfig action,
OverAllState withState) {
try {
doListeners(NODEBEFORE, null);
return action.apply(withState, config).thenApply(updateState -> {
try {
if (action instanceof CommandNode.AsyncCommandNodeActionWithConfig) {
AsyncCommandAction commandAction = (AsyncCommandAction) updateState.get("command");
Command command = commandAction.apply(withState, config).join();
this.currentState = OverAllState.updateState(currentState, command.update(),
keyStrategyMap);
this.overAllState.updateState(command.update());
cursor.setNextNodeId(command.gotoNode());
return Data.of(getNodeOutput());
}
Optional<Data<Output>> embed = getEmbedGenerator(updateState);
if (embed.isPresent()) {
return embed.get();
}
this.currentState = OverAllState.updateState(currentState, updateState, keyStrategyMap);
this.overAllState.updateState(updateState);
if (compileConfig.interruptBeforeEdge()
&& compileConfig.interruptsAfter().contains(cursor.currentNodeId())) {
// nextNodeId = INTERRUPTAFTER;
cursor.setNextNodeId(INTERRUPTAFTER);
}
else {
var nextNodeCommand = nextNodeId(cursor.currentNodeId(), currentState, config);
// nextNodeId = nextNodeCommand.gotoNode();
cursor.setNextNodeId(nextNodeCommand.gotoNode());
currentState = nextNodeCommand.update();
}
return Data.of(getNodeOutput());
}
catch (Exception e) {
throw new CompletionException(e);
}
}).whenComplete((outputData, throwable) -> doListeners(NODEAFTER, null));
}
catch (Exception e) {
return failedFuture(e);
}
}
/**
* Determines the next node ID based on the current node ID and state.
* @param nodeId the current node ID
* @param state the current state
* @return the next node command
* @throws Exception if there is an error determining the next node ID
*/
private Command nextNodeId(String nodeId, Map<String, Object> state, RunnableConfig config) throws Exception {
return nextNodeId(edges.get(nodeId), state, nodeId, config);
}
private Command nextNodeId(EdgeValue route, Map<String, Object> state, String nodeId, RunnableConfig config)
throws Exception {
if (route == null) {
throw RunnableErrors.missingEdge.exception(nodeId);
}
if (route.id() != null) {
return new Command(route.id(), state);
}
if (route.value() != null) {
var command = route.value().action().apply(this.overAllState, config).get();
var newRoute = command.gotoNode();
String result = route.value().mappings().get(newRoute);
if (result == null) {
throw RunnableErrors.missingNodeInEdgeMapping.exception(nodeId, newRoute);
}
var currentState = OverAllState.updateState(state, command.update(), keyStrategyMap);
this.overAllState.updateState(command.update());
return new Command(result, currentState);
}
throw RunnableErrors.executionError.exception(format("invalid edge value for nodeId: [%s] !", nodeId));
}
private CompletableFuture<Output> getNodeOutput() throws Exception {
Optional<Checkpoint> cp = addCheckpoint(config, cursor.currentNodeId(), currentState, cursor.nextNodeId());
return completedFuture((cp.isPresent() && config.streamMode() == StreamMode.SNAPSHOTS)
? buildStateSnapshot(cp.get()) : buildNodeOutput(cursor.currentNodeId()));
}
@Override
public Data<Output> next() {
try {
// GUARD: CHECK MAX ITERATION REACHED
if (++iteration > maxIterations) {
// log.warn( "Maximum number of iterations ({}) reached!",
// maxIterations);
return Data.error(new IllegalStateException(
format("Maximum number of iterations (%d) reached!", maxIterations)));
}
// GUARD: CHECK IF IT IS END
if (cursor.nextNodeId() == null && cursor.currentNodeId() == null) {
return releaseThread().map(Data::<Output>done).orElseGet(() -> Data.done(currentState));
}
// IS IT A RESUME FROM EMBED ?
if (returnFromEmbed) {
final CompletableFuture<Output> future = getNodeOutput();
returnFromEmbed = false;
return Data.of(future);
}
if (cursor.currentNodeId() != null && config.isInterrupted(cursor.currentNodeId())) {
config.withNodeResumed(cursor.currentNodeId());
return Data.done(currentState);
}
if (START.equals(cursor.currentNodeId())) {
doListeners(START, null);
var nextNodeCommand = getEntryPoint(currentState, config);
cursor.setNextNodeId(nextNodeCommand.gotoNode());
currentState = nextNodeCommand.update();
var cp = addCheckpoint(config, START, currentState, cursor.nextNodeId());
var output = (cp.isPresent() && config.streamMode() == StreamMode.SNAPSHOTS)
? buildStateSnapshot(cp.get()) : buildNodeOutput(cursor.currentNodeId());
cursor.setCurrentNodeId(cursor.nextNodeId());
// currentNodeId = nextNodeId;
return Data.of(output);
}
if (END.equals(cursor.nextNodeId())) {
cursor.reset();
doListeners(END, null);
// nextNodeId = null;
// currentNodeId = null;
return Data.of(buildNodeOutput(END));
}
if (cursor.isResumed()) {
if (compileConfig.interruptBeforeEdge() && Objects.equals(cursor.nextNodeId(), INTERRUPTAFTER)) {
var nextNodeCommand = nextNodeId(cursor.resumeFrom(), currentState, config);
// nextNodeId = nextNodeCommand.gotoNode();
cursor.setNextNodeId(nextNodeCommand.gotoNode());
currentState = nextNodeCommand.update();
cursor.setCurrentNodeId(null);
}
cursor.setResumeFrom(null);
}
// check on previous node
if (shouldInterruptAfter(cursor.currentNodeId(), cursor.nextNodeId())) {
return Data.done(cursor.currentNodeId());
}
if (shouldInterruptBefore(cursor.nextNodeId(), cursor.currentNodeId())) {
return Data.done(cursor.nextNodeId());
}
cursor.setCurrentNodeId(cursor.nextNodeId());
var action = nodes.get(cursor.currentNodeId());
if (action == null)
throw RunnableErrors.missingNode.exception(cursor.currentNodeId());
return evaluateAction(action, this.overAllState).get();
}
catch (Exception e) {
doListeners(ERROR, e);
log.error(e.getMessage(), e);
return Data.error(e);
}
}
private void doListeners(String scene, Exception e) {
Deque<GraphLifecycleListener> listeners = new LinkedBlockingDeque<>(compileConfig.lifecycleListeners());
LifeListenerUtil.processListenersLIFO(this.cursor.currentNodeId(), listeners, this.currentState,
this.config, scene, e);
}
}
}
/**
* The type Processed nodes edges and config.
*/
record ProcessedNodesEdgesAndConfig(StateGraph.Nodes nodes, StateGraph.Edges edges, Set<String> interruptsBefore,
Set<String> interruptsAfter) {
/**
* Instantiates a new Processed nodes edges and config.
* @param stateGraph the state graph
* @param config the config
*/
ProcessedNodesEdgesAndConfig(StateGraph stateGraph, CompileConfig config) {
this(stateGraph.nodes, stateGraph.edges, config.interruptsBefore(), config.interruptsAfter());
}
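/**
* Flattens the graph: each subgraph node is replaced by the recursively processed
* nodes and edges of its subgraph, with ids rewritten through {@code formatId}.
* @param stateGraph the state graph
* @param config the compile config
* @return the processed nodes, edges and interruption sets
* @throws GraphStateException if a subgraph cannot be inlined
*/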
static ProcessedNodesEdgesAndConfig process(StateGraph stateGraph, CompileConfig config)
throws GraphStateException {
var subgraphNodes = stateGraph.nodes.onlySubStateGraphNodes();
if (subgraphNodes.isEmpty()) {
return new ProcessedNodesEdgesAndConfig(stateGraph, config);
}
var interruptsBefore = config.interruptsBefore();
var interruptsAfter = config.interruptsAfter();
var nodes = new StateGraph.Nodes(stateGraph.nodes.exceptSubStateGraphNodes());
var edges = new StateGraph.Edges(stateGraph.edges.elements);
for (var subgraphNode : subgraphNodes) {
var sgWorkflow = subgraphNode.subGraph();
ProcessedNodesEdgesAndConfig processedSubGraph = process(sgWorkflow, config);
StateGraph.Nodes processedSubGraphNodes = processedSubGraph.nodes;
StateGraph.Edges processedSubGraphEdges = processedSubGraph.edges;
//
// Process START Node
//
var sgEdgeStart = processedSubGraphEdges.edgeBySourceId(START).orElseThrow();
if (sgEdgeStart.isParallel()) {
throw new GraphStateException("subgraph not support start with parallel branches yet!");
}
var sgEdgeStartTarget = sgEdgeStart.target();
if (sgEdgeStartTarget.id() == null) {
throw new GraphStateException(format("the target for node '%s' is null!", subgraphNode.id()));
}
var sgEdgeStartRealTargetId = subgraphNode.formatId(sgEdgeStartTarget.id());
// Process Interruption (Before) Subgraph(s)
interruptsBefore = interruptsBefore.stream()
.map(interrupt -> Objects.equals(subgraphNode.id(), interrupt) ? sgEdgeStartRealTargetId : interrupt)
.collect(Collectors.toUnmodifiableSet());
var edgesWithSubgraphTargetId = edges.edgesByTargetId(subgraphNode.id());
if (edgesWithSubgraphTargetId.isEmpty()) {
throw new GraphStateException(
format("the node '%s' is not present as target in graph!", subgraphNode.id()));
}
for (var edgeWithSubgraphTargetId : edgesWithSubgraphTargetId) {
var newEdge = edgeWithSubgraphTargetId.withSourceAndTargetIdsUpdated(subgraphNode, Function.identity(),
id -> new EdgeValue((Objects.equals(id, subgraphNode.id())
? subgraphNode.formatId(sgEdgeStartTarget.id()) : id)));
edges.elements.remove(edgeWithSubgraphTargetId);
edges.elements.add(newEdge);
}
//
// Process END Nodes
//
var sgEdgesEnd = processedSubGraphEdges.edgesByTargetId(END);
var edgeWithSubgraphSourceId = edges.edgeBySourceId(subgraphNode.id()).orElseThrow();
if (edgeWithSubgraphSourceId.isParallel()) {
throw new GraphStateException("subgraph not support routes to parallel branches yet!");
}
// Process Interruption (After) Subgraph(s)
if (interruptsAfter.contains(subgraphNode.id())) {
var exceptionMessage = (edgeWithSubgraphSourceId.target()
.id() == null) ? "'interruption after' on subgraph is not supported yet!" : format(
"'interruption after' on subgraph is not supported yet! consider to use 'interruption before' node: '%s'",
edgeWithSubgraphSourceId.target().id());
throw new GraphStateException(exceptionMessage);
}
sgEdgesEnd.stream()
.map(e -> e.withSourceAndTargetIdsUpdated(subgraphNode, subgraphNode::formatId,
id -> (Objects.equals(id, END) ? edgeWithSubgraphSourceId.target()
: new EdgeValue(subgraphNode.formatId(id)))))
.forEach(edges.elements::add);
edges.elements.remove(edgeWithSubgraphSourceId);
//
// Process edges
//
processedSubGraphEdges.elements.stream()
.filter(e -> !Objects.equals(e.sourceId(), START))
.filter(e -> !e.anyMatchByTargetId(END))
.map(e -> e.withSourceAndTargetIdsUpdated(subgraphNode, subgraphNode::formatId,
id -> new EdgeValue(subgraphNode.formatId(id))))
.forEach(edges.elements::add);
//
// Process nodes
//
processedSubGraphNodes.elements.stream().map(n -> {
if (n instanceof CommandNode commandNode) {
Map<String, String> mappings = commandNode.getMappings();
HashMap<String, String> newMappings = new HashMap<>();
mappings.forEach((key, value) -> {
newMappings.put(key, subgraphNode.formatId(value));
});
return new CommandNode(subgraphNode.formatId(n.id()),
AsyncCommandAction.nodeasync((state, config1) -> {
Command command = commandNode.getAction().apply(state, config1).join();
String newGoToNode = subgraphNode.formatId(command.gotoNode());
return new Command(newGoToNode, command.update());
}), newMappings);
}
return n.withIdUpdated(subgraphNode::formatId);
}).forEach(nodes.elements::add);
}
return new ProcessedNodesEdgesAndConfig(nodes, edges, interruptsBefore, interruptsAfter);
}
}
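The `process` method above inlines each subgraph into its parent graph: subgraph node ids are rewritten through `formatId`, edges that pointed at the subgraph node are redirected to the subgraph's first node, and subgraph edges that reached `END` are redirected to the node that follows the subgraph in the parent. Below is a minimal sketch of that rewiring, assuming plain id-to-id edges stored in a map and a `subgraphId-nodeId` id format. The class `SubgraphFlattenSketch`, its `formatId`, and the marker values are hypothetical stand-ins, not the library's types.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

// Hypothetical sketch, not the library's implementation: edges are plain
// "source id -> target id" entries, so conditional and parallel edges are not modeled.
final class SubgraphFlattenSketch {

    // Placeholder marker ids used only by this sketch.
    static final String START = "__START__";
    static final String END = "__END__";

    // Assumed id format for inlined nodes.
    static String formatId(String subgraphId, String nodeId) {
        return subgraphId + "-" + nodeId;
    }

    static Map<String, String> flatten(Map<String, String> parent, String subgraphId,
            Map<String, String> subgraph) {
        // First real node of the subgraph, and the parent's node after the subgraph.
        String entry = formatId(subgraphId,
                Objects.requireNonNull(subgraph.get(START), "subgraph has no START edge"));
        String exit = Objects.requireNonNull(parent.get(subgraphId),
                "subgraph node has no outgoing edge");

        Map<String, String> result = new LinkedHashMap<>();
        for (Map.Entry<String, String> e : parent.entrySet()) {
            if (e.getKey().equals(subgraphId)) {
                continue; // the edge leaving the subgraph node is replaced below
            }
            // Edges that targeted the subgraph node now target its entry node.
            result.put(e.getKey(), e.getValue().equals(subgraphId) ? entry : e.getValue());
        }
        for (Map.Entry<String, String> e : subgraph.entrySet()) {
            if (e.getKey().equals(START)) {
                continue; // already covered by the redirected parent edge
            }
            // Edges to END leave the subgraph; every other id gets the prefix.
            String target = e.getValue().equals(END) ? exit : formatId(subgraphId, e.getValue());
            result.put(formatId(subgraphId, e.getKey()), target);
        }
        return result;
    }

    // Parent: START -> A -> sub -> B -> END; subgraph: START -> x -> y -> END.
    static Map<String, String> demo() {
        Map<String, String> parent = new LinkedHashMap<>();
        parent.put(START, "A");
        parent.put("A", "sub");
        parent.put("sub", "B");
        parent.put("B", END);
        Map<String, String> subgraph = new LinkedHashMap<>();
        subgraph.put(START, "x");
        subgraph.put("x", "y");
        subgraph.put("y", END);
        return flatten(parent, "sub", subgraph);
    }

    public static void main(String[] args) {
        System.out.println(demo());
    }
}
```

Conditional edges, parallel branches, and the interruption sets are left out of this sketch; the real method either handles them or rejects them with a `GraphStateException`, as the code above shows.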
Node
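The core class for graph nodes: it defines a single node in the graph, consisting of the node's identifier and an optional action factory.
| Field name | Field type | Description |
| id | String | Unique identifier of the node |
| actionFactory | ActionFactory | Factory that creates the action the node executes |
Public methods
| Method name | Description |
| isParallel | Checks whether the node is a parallel node; the current implementation always returns false |
| withIdUpdated | Returns a new node instance whose ID has been transformed by the given function |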
package com.alibaba.cloud.ai.graph.internal.node;
import com.alibaba.cloud.ai.graph.CompileConfig;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import java.util.Objects;
import java.util.function.Function;
import static java.lang.String.format;
/**
* Represents a node in a graph, characterized by a unique identifier and a factory for
* creating actions to be executed by the node. The actions operate on
* {@link OverAllState}.
*/
public class Node {
public interface ActionFactory {
AsyncNodeActionWithConfig apply(CompileConfig config) throws GraphStateException;
}
private final String id;
private final ActionFactory actionFactory;
public Node(String id, ActionFactory actionFactory) {
this.id = id;
this.actionFactory = actionFactory;
}
/**
* Constructor that accepts only the `id` and sets `actionFactory` to null.
* @param id the unique identifier for the node
*/
public Node(String id) {
this(id, null);
}
/**
* id
* @return the unique identifier for the node.
*/
public String id() {
return id;
}
/**
* actionFactory
* @return a factory function that takes a {@link CompileConfig} and returns an
* {@link AsyncNodeActionWithConfig} instance for the specified {@code State}.
*/
public ActionFactory actionFactory() {
return actionFactory;
}
public boolean isParallel() {
// return id.startsWith(PARALLELPREFIX);
return false;
}
public Node withIdUpdated(Function<String, String> newId) {
return new Node(newId.apply(id), actionFactory);
}
/**
* Checks if this node is equal to another object.
* @param o the object to compare with
* @return true if this node is equal to the specified object, false otherwise
*/
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null)
return false;
if (o instanceof Node node) {
return Objects.equals(id, node.id);
}
return false;
}
/**
* Returns the hash code value for this node.
* @return the hash code value for this node
*/
@Override
public int hashCode() {
return Objects.hash(id);
}
@Override
public String toString() {
return format("Node(%s,%s)", id, actionFactory != null ? "action" : "null");
}
}
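In the `Node` class above, `equals` and `hashCode` consider only `id`, so two nodes with the same id count as the same element in a hash-based collection regardless of their action factory, while `withIdUpdated` yields a node with a new identity. Here is a small sketch of those semantics, using a hypothetical `SketchNode` stand-in whose action is reduced to a label (it is not the library's `ActionFactory`).

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Objects;
import java.util.function.Function;

final class NodeIdentitySketch {

    // Hypothetical stand-in for Node; the action is reduced to a label.
    static final class SketchNode {
        private final String id;
        private final String action;

        SketchNode(String id, String action) {
            this.id = id;
            this.action = action;
        }

        String id() {
            return id;
        }

        // Returns a new node with a transformed id and the same action.
        SketchNode withIdUpdated(Function<String, String> newId) {
            return new SketchNode(newId.apply(id), action);
        }

        // Identity is based on the id only, mirroring the class above.
        @Override
        public boolean equals(Object o) {
            return o instanceof SketchNode other && Objects.equals(id, other.id);
        }

        @Override
        public int hashCode() {
            return Objects.hash(id);
        }
    }

    // Number of distinct nodes according to equals/hashCode.
    static int distinctCount(SketchNode... nodes) {
        return new HashSet<>(Arrays.asList(nodes)).size();
    }

    public static void main(String[] args) {
        SketchNode a1 = new SketchNode("a", "first");
        SketchNode a2 = new SketchNode("a", "second");
        SketchNode renamed = a1.withIdUpdated(id -> "sub-" + id);
        System.out.println(distinctCount(a1, a2));      // same id, different action
        System.out.println(distinctCount(a1, renamed)); // different ids
        System.out.println(renamed.id());
    }
}
```

This id-only identity is what lets the subgraph processing shown earlier add renamed nodes to the parent graph without clashing with the originals, provided the renamed ids are unique.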
Edge
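Defines the connection between nodes; implemented as a record.
| Field name | Field type | Description |
| sourceId | String | ID of the source node, i.e. the node where the edge starts |
| targets | List<EdgeValue> | Target nodes or conditions of the edge |
| Category | Method name | Description |
| Constructors | Edge | - (String id): creates an edge with only a source node - (String sourceId, EdgeValue target): creates an edge from the source node to a single target - (String sourceId, List<EdgeValue> targets): creates an edge from the source node to multiple targets |
| Edge checks | isParallel | Checks whether the edge is parallel (more than one target) |
| | target | Returns the single target; throws an exception if the edge is parallel |
| | anyMatchByTargetId | Checks whether any target matches the given target node ID |
| | withSourceAndTargetIdsUpdated | Returns a new Edge instance with the source node ID and target values updated |
| | validate | Validates the edge |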
package com.alibaba.cloud.ai.graph.internal.edge;
import com.alibaba.cloud.ai.graph.exception.Errors;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.internal.node.Node;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import static com.alibaba.cloud.ai.graph.StateGraph.START;
import static java.lang.String.format;
/**
* Represents an edge in a graph with a source ID and a target value.
*
* @param sourceId The ID of the source node.
* @param targets The targets value associated with the edge.
*/
public record Edge(String sourceId, List<EdgeValue> targets) {
public Edge(String sourceId, EdgeValue target) {
this(sourceId, List.of(target));
}
public Edge(String id) {
this(id, List.of());
}
public boolean isParallel() {
return targets.size() > 1;
}
public EdgeValue target() {
if (isParallel()) {
throw new IllegalStateException(format("Edge '%s' is parallel", sourceId));
}
return targets.get(0);
}
public boolean anyMatchByTargetId(String targetId) {
return targets().stream()
.anyMatch(v -> (v.id() != null) ? Objects.equals(v.id(), targetId)
: v.value().mappings().containsValue(targetId)
);
}
public Edge withSourceAndTargetIdsUpdated(Node node, Function<String, String> newSourceId,
Function<String, EdgeValue> newTarget) {
var newTargets = targets().stream().map(t -> t.withTargetIdsUpdated(newTarget)).toList();
return new Edge(newSourceId.apply(sourceId), newTargets);
}
public void validate(StateGraph.Nodes nodes) throws GraphStateException {
if (!Objects.equals(sourceId(), START) && !nodes.anyMatchById(sourceId())) {
throw Errors.missingNodeReferencedByEdge.exception(sourceId());
}
if (isParallel()) { // check for duplicate targets
Set<String> duplicates = targets.stream()
// group by target id and count occurrences
.collect(Collectors.groupingBy(EdgeValue::id, Collectors.counting()))
.entrySet()
.stream()
// keep only ids that occur more than once
.filter(entry -> entry.getValue() > 1)
.map(Map.Entry::getKey)
.collect(Collectors.toSet());
if (!duplicates.isEmpty()) {
throw Errors.duplicateEdgeTargetError.exception(sourceId(), duplicates);
}
}
for (EdgeValue target : targets) {
validate(target, nodes);
}
}
private void validate(EdgeValue target, StateGraph.Nodes nodes) throws GraphStateException {
if (target.id() != null) {
if (!Objects.equals(target.id(), StateGraph.END) && !nodes.anyMatchById(target.id())) {
throw Errors.missingNodeReferencedByEdge.exception(target.id());
}
}
else if (target.value() != null) {
for (String nodeId : target.value().mappings().values()) {
if (!Objects.equals(nodeId, StateGraph.END) && !nodes.anyMatchById(nodeId)) {
throw Errors.missingNodeInEdgeMapping.exception(sourceId(), nodeId);
}
}
}
else {
throw Errors.invalidEdgeTarget.exception(sourceId());
}
}
/**
* Checks if this edge is equal to another object.
* @param o the object to compare with
* @return true if this edge is equal to the specified object, false otherwise
*/
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
Edge node = (Edge) o;
return Objects.equals(sourceId, node.sourceId);
}
/**
* Returns the hash code value for this edge.
* @return the hash code value for this edge
*/
@Override
public int hashCode() {
return Objects.hash(sourceId);
}
}
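To make the parallel-edge behaviour above easier to try out in isolation, here is a minimal sketch. `SimpleEdge` is a hypothetical stand-in, not the library's `Edge`: it models targets as plain node IDs and leaves out `EdgeValue`, conditions, and `StateGraph.Nodes`. It mirrors three ideas from the listing: an edge is parallel when it has more than one target, `target()` refuses to run on a parallel edge, and duplicates are found by grouping and counting.

```java
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

class ParallelEdgeSketch {

    // Hypothetical stand-in for Edge: only the source id and fixed target ids are modeled.
    record SimpleEdge(String sourceId, List<String> targetIds) {

        boolean isParallel() {
            return targetIds.size() > 1;
        }

        // Only meaningful for a single-target edge, as in Edge.target().
        String target() {
            if (isParallel()) {
                throw new IllegalStateException(String.format("Edge '%s' is parallel", sourceId));
            }
            return targetIds.get(0);
        }

        // Same grouping-and-counting idea as Edge.validate: ids that occur more than once.
        Set<String> duplicateTargets() {
            return targetIds.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()))
                .entrySet()
                .stream()
                .filter(entry -> entry.getValue() > 1)
                .map(Map.Entry::getKey)
                .collect(Collectors.toSet());
        }
    }

    public static void main(String[] args) {
        var single = new SimpleEdge("a", List.of("b"));
        var fanOut = new SimpleEdge("a", List.of("b", "c", "b"));

        System.out.println(single.target());           // b
        System.out.println(fanOut.isParallel());       // true
        System.out.println(fanOut.duplicateTargets()); // [b]
    }
}
```

In the real `validate` method a non-empty duplicate set leads to `Errors.duplicateEdgeTargetError`; the sketch only returns the set so the check can be inspected directly.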
EdgeValue
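The core class for an edge's value in the graph. It defines either the edge's target node or its condition, and therefore supports conditional edges.
| Field name | Field type | Description |
| id | String | Unique identifier of the target node; used for fixed edges |
| value | EdgeCondition | Condition associated with the edge value; used for conditional edges |
| Method name | Description | |
| Constructor | EdgeValue | - (String id): creates an EdgeValue instance with only a target node ID; the condition is null - (EdgeCondition value): creates an EdgeValue instance with only a condition; the target node ID is null - (String id, EdgeCondition value): creates an EdgeValue instance with both a target node ID and a condition |
| Update method | withTargetIdsUpdated(Function |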
package com.alibaba.cloud.ai.graph.internal.edge;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
* @param id The unique identifier for the edge value.
* @param value The condition associated with the edge value.
*/
public record EdgeValue(String id, EdgeCondition value) {
public EdgeValue(String id) {
this(id, null);
}
public EdgeValue(EdgeCondition value) {
this(null, value);
}
EdgeValue withTargetIdsUpdated(Function<String, EdgeValue> target) {
if (id != null) {
return target.apply(id);
}
var newMappings = value.mappings().entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> {
var v = target.apply(e.getValue());
return (v.id() != null) ? v.id() : e.getValue();
}));
return new EdgeValue(null, new EdgeCondition(value.action(), newMappings));
}
}
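The two branches of `withTargetIdsUpdated` can be illustrated with a small self-contained sketch. `SimpleEdgeValue` and `SimpleCondition` are hypothetical stand-ins that drop the `AsyncCommandAction`, and the `prefixed` renaming rule is an assumption chosen for illustration, not something the library defines. A fixed edge is replaced by whatever the function returns, while a conditional edge keeps its routing keys and only rewrites the mapped node IDs.

```java
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

class EdgeValueRemapSketch {

    // Hypothetical stand-in: the condition keeps only its mappings; the async action is omitted.
    record SimpleCondition(Map<String, String> mappings) {
    }

    // Hypothetical stand-in for EdgeValue with the same two-branch update logic.
    record SimpleEdgeValue(String id, SimpleCondition value) {

        SimpleEdgeValue withTargetIdsUpdated(Function<String, SimpleEdgeValue> target) {
            if (id != null) {
                // Fixed edge: the function decides the replacement value.
                return target.apply(id);
            }
            // Conditional edge: keep each routing key, rewrite the node id it maps to.
            var newMappings = value.mappings().entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, e -> {
                    var v = target.apply(e.getValue());
                    return (v.id() != null) ? v.id() : e.getValue();
                }));
            return new SimpleEdgeValue(null, new SimpleCondition(newMappings));
        }
    }

    // Assumed renaming rule for the example: prefix every node id with "sub_".
    static SimpleEdgeValue prefixed(String nodeId) {
        return new SimpleEdgeValue("sub_" + nodeId, null);
    }

    public static void main(String[] args) {
        var fixed = new SimpleEdgeValue("review", null);
        var conditional = new SimpleEdgeValue(null,
            new SimpleCondition(Map.of("ok", "publish", "retry", "draft")));

        var newFixed = fixed.withTargetIdsUpdated(EdgeValueRemapSketch::prefixed);
        var newConditional = conditional.withTargetIdsUpdated(EdgeValueRemapSketch::prefixed);

        System.out.println(newFixed.id());                                // sub_review
        System.out.println(newConditional.value().mappings().get("ok"));  // sub_publish
        System.out.println(newConditional.value().mappings().get("retry")); // sub_draft
    }
}
```

Because the method returns new records instead of mutating the existing ones, the original edge values stay usable after the rename.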
EdgeCondition
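Defines the condition of a conditional edge and its mapping relationships.
| Field name | Field type | Description |
| action | AsyncCommandAction | Asynchronous command action that runs the condition-evaluation logic |
| mappings | Map |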
package com.alibaba.cloud.ai.graph.internal.edge;
import com.alibaba.cloud.ai.graph.action.AsyncCommandAction;
import java.util.Map;
import static java.lang.String.format;
/**
* Represents a condition associated with an edge in a graph.
*
* @param action The action to be performed asynchronously when the edge condition is met.
* @param mappings A map of string key-value pairs representing additional mappings for
* the edge condition.
*/
public record EdgeCondition(AsyncCommandAction action, Map<String, String> mappings) {
@Override
public String toString() {
return format("EdgeCondition[ %s, mapping=%s", action != null ? "action" : "null", mappings);
}
}
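The following sketch shows one way an action and its mappings could cooperate to pick the next node. It is an assumption about the resolution step, not the library's actual runtime code: `SimpleCondition` is a hypothetical stand-in whose action is a plain `Function` returning a `CompletableFuture<String>` routing key (the real `AsyncCommandAction` has a different signature), and `resolveNext` is an invented helper that looks that key up in `mappings`.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;

class ConditionalRoutingSketch {

    // Hypothetical stand-in for EdgeCondition: the action asynchronously yields a routing key.
    record SimpleCondition(Function<Map<String, Object>, CompletableFuture<String>> action,
            Map<String, String> mappings) {

        // Assumed resolution step: run the action, then translate its key through mappings.
        String resolveNext(Map<String, Object> state) {
            String key = action.apply(state).join();
            String nodeId = mappings.get(key);
            if (nodeId == null) {
                throw new IllegalStateException("No mapping for key '" + key + "'");
            }
            return nodeId;
        }
    }

    // Example condition: route on a numeric "score" entry in the state (hypothetical key).
    static SimpleCondition scoreCondition() {
        return new SimpleCondition(
            state -> CompletableFuture.completedFuture(
                ((Integer) state.getOrDefault("score", 0)) >= 60 ? "pass" : "fail"),
            Map.of("pass", "publish", "fail", "revise"));
    }

    public static void main(String[] args) {
        var condition = scoreCondition();
        System.out.println(condition.resolveNext(Map.of("score", 80))); // publish
        System.out.println(condition.resolveNext(Map.of("score", 10))); // revise
    }
}
```

Keeping the routing keys separate from the node IDs is what lets `withTargetIdsUpdated` rename target nodes without touching the action, and it is also why `Edge.validate` checks every value in `mappings` against the known nodes.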
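Learning and discussion group
Hello, I'm 影子. I have worked at 🐻, a new-energy company, and 老铁, and I also serve as a Committer in the Spring AI Alibaba open-source community. I recently set up a discussion group: one person travels fast, but a group travels far. Follow the official account to get my personal WeChat, then add me with the note "交流" to join the group. I also maintain a long-running set of Feishu cloud-document notes with systematic interview material for backend and big data, available for free by private message.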





