Spring AI Alibaba Graph Source Code Walkthrough Series: Core Startup Classes

[!TIP]
Based on Spring AI Alibaba version 1.0.0.3

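Core classes

OverAllState

The core state-management class used while a graph executes. It stores and manages the data shared between the nodes of the graph.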
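
To make the idea of node-shared state concrete, here is a minimal, self-contained sketch of a state map whose keys are updated through per-key strategies. It is not the library class: `MiniState`, `REPLACE`, and `APPEND` are hypothetical names invented for this illustration, and the strategies are modelled as plain `BinaryOperator<Object>` values rather than the library's `KeyStrategy` type. It mirrors one behaviour visible in the `updateState(Map)` method of the listing: keys that have no registered strategy are skipped.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BinaryOperator;

// Simplified stand-in for a shared graph state; not the library class.
class KeyStrategySketch {

    // Hypothetical strategies: (oldValue, newValue) -> value to store.
    static final BinaryOperator<Object> REPLACE = (oldValue, newValue) -> newValue;

    static final BinaryOperator<Object> APPEND = (oldValue, newValue) -> {
        List<Object> merged = new ArrayList<>();
        if (oldValue instanceof List<?>) {
            merged.addAll((List<?>) oldValue);
        }
        merged.add(newValue);
        return merged;
    };

    static final class MiniState {

        private final Map<String, Object> data = new HashMap<>();

        private final Map<String, BinaryOperator<Object>> strategies = new HashMap<>();

        MiniState register(String key, BinaryOperator<Object> strategy) {
            strategies.put(key, strategy);
            return this;
        }

        // Keys without a registered strategy are ignored.
        Map<String, Object> updateState(Map<String, Object> partial) {
            partial.forEach((key, value) -> {
                BinaryOperator<Object> strategy = strategies.get(key);
                if (strategy != null) {
                    data.put(key, strategy.apply(data.get(key), value));
                }
            });
            return Collections.unmodifiableMap(data);
        }
    }

    public static void main(String[] args) {
        MiniState state = new MiniState().register("input", REPLACE).register("messages", APPEND);
        state.updateState(Map.of("input", "hello", "messages", "m1", "ignored", 42));
        Map<String, Object> result = state.updateState(Map.of("input", "world", "messages", "m2"));
        System.out.println(result.get("input"));           // world
        System.out.println(result.get("messages"));        // [m1, m2]
        System.out.println(result.containsKey("ignored")); // false
    }
}
```

Each node only has to return the keys it wants to change; how the new value combines with the old one is decided once, when the key is registered.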

| Field name | Field type | Description |
| --- | --- | --- |
| data | Map | |

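Public methods

| | Method name | Description |
| --- | --- | --- |
| Constructors | OverAllState | Supports four construction modes:<br>- No arguments: the default constructor; initializes empty data and strategy maps and registers the default input key<br>- Map |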

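The static inner class HumanFeedback handles and stores human feedback received while the workflow executes.

  • Map<String, Object> data: the feedback payload itself; it can hold arbitrary key-value pairs
  • String nextNodeId: the ID of the next node to execute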
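
The two fields above can be pictured with a small, self-contained sketch. It assumes that feedback data is copied into the running state and that `nextNodeId` tells the caller where to continue. That resume step is an assumption made for illustration, since this section does not show how the library consumes the feedback. `Feedback` and `resume` are hypothetical names, not library API.

```java
import java.util.HashMap;
import java.util.Map;

// Illustrative stand-in for a human-feedback holder; not the library class.
class HumanFeedbackSketch {

    static final class Feedback {

        private final Map<String, Object> data;

        private final String nextNodeId;

        Feedback(Map<String, Object> data, String nextNodeId) {
            this.data = data;
            this.nextNodeId = nextNodeId;
        }

        Map<String, Object> data() {
            return data;
        }

        String nextNodeId() {
            return nextNodeId;
        }
    }

    // Hypothetical resume step (an assumption): copy the feedback into the
    // state and report which node should run next.
    static String resume(Map<String, Object> state, Feedback feedback) {
        state.putAll(feedback.data());
        return feedback.nextNodeId();
    }

    public static void main(String[] args) {
        Map<String, Object> state = new HashMap<>();
        state.put("draft", "v1");

        Feedback feedback = new Feedback(Map.of("approved", true), "publish");
        String next = resume(state, feedback);

        System.out.println(next);                  // publish
        System.out.println(state.get("approved")); // true
    }
}
```
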
package com.alibaba.cloud.ai.graph;

import org.springframework.util.CollectionUtils;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;

import static com.alibaba.cloud.ai.graph.utils.CollectionsUtils.entryOf;
import static java.util.Collections.unmodifiableMap;
import static java.util.Optional.ofNullable;

public final class OverAllState implements Serializable {

    /**
     * Internal map storing the actual state data. All get/set operations on state values
     * go through this map.
     */
    private final Map<String, Object> data;

    /**
     * Mapping of keys to their respective update strategies. Determines how values for
     * each key should be merged or updated.
     */
    private final Map<String, KeyStrategy> keyStrategies;

    /**
     * Indicates whether this state is being used to resume a previously interrupted
     * execution. If true, certain initialization steps may be skipped.
     */
    private Boolean resume;

    /**
     * Holds optional human feedback information provided during execution. May be null if
     * no feedback was given.
     */
    private HumanFeedback humanFeedback;

    /**
     * Optional message indicating that the execution was interrupted. If non-null,
     * indicates that the graph should halt or handle the interruption.
     */
    private String interruptMessage;

    /**
     * The default key used for standard input injection into the state. Typically used
     * when initializing the state with user or external input.
     */
    public static final String DEFAULTINPUTKEY = "input";

    /**
     * Reset.
     */
    public void reset() {
       this.data.clear();
    }

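    /**
     * Creates a snapshot of this state. The data and key-strategy maps are shallow-copied
     * and the resume flag is carried over; human feedback and the interrupt message are
     * not copied.
     * @return an {@code Optional} holding the snapshot
     */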
    public Optional<OverAllState> snapShot() {
       return Optional.of(new OverAllState(new HashMap<>(this.data), new HashMap<>(this.keyStrategies), this.resume));
    }

    /**
     * Instantiates a new Over all state.
     * @param resume the is resume
     */
    public OverAllState(boolean resume) {
       this.data = new HashMap<>();
       this.keyStrategies = new HashMap<>();
       this.resume = resume;
    }

    /**
     * Instantiates a new Over all state.
     * @param data the data
     */
    public OverAllState(Map<String, Object> data) {
       this.data = new HashMap<>(data);
       this.keyStrategies = new HashMap<>();
       this.resume = false;
    }

    /**
     * Instantiates a new Over all state.
     */
    public OverAllState() {
       this.data = new HashMap<>();
       this.keyStrategies = new HashMap<>();
       this.registerKeyAndStrategy(OverAllState.DEFAULTINPUTKEY, new ReplaceStrategy());
       this.resume = false;
    }

    /**
     * Instantiates a new Over all state.
     * @param data the data
     * @param keyStrategies the key strategies
     * @param resume the resume
     */
    protected OverAllState(Map<String, Object> data, Map<String, KeyStrategy> keyStrategies, Boolean resume) {
       this.data = data;
       this.keyStrategies = keyStrategies;
       this.registerKeyAndStrategy(OverAllState.DEFAULTINPUTKEY, new ReplaceStrategy());
       this.resume = resume;
    }

    /**
     * Interrupt message string.
     * @return the string
     */
    public String interruptMessage() {
       return interruptMessage;
    }

    /**
     * Sets interrupt message.
     * @param interruptMessage the interrupt message
     */
    public void setInterruptMessage(String interruptMessage) {
       this.interruptMessage = interruptMessage;
    }

    /**
     * With human feedback.
     * @param humanFeedback the human feedback
     */
    public void withHumanFeedback(HumanFeedback humanFeedback) {
       this.humanFeedback = humanFeedback;
    }

    /**
     * Human feedback human feedback.
     * @return the human feedback
     */
    public HumanFeedback humanFeedback() {
       return this.humanFeedback;
    }

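    /**
     * Returns a new state marked as resumed. The new instance shares this state's data
     * and key-strategy maps rather than copying them.
     * @return the new state with the resume flag set to true
     */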
    public OverAllState copyWithResume() {
       return new OverAllState(this.data, this.keyStrategies, true);
    }

    /**
     * With resume.
     */
    public void withResume() {
       this.resume = true;
    }

    /**
     * Without resume.
     */
    public void withoutResume() {
       this.resume = false;
    }

    /**
     * Is resume boolean.
     * @return the boolean
     */
    public boolean isResume() {
       return this.resume;
    }

    /**
     * Clears all data in the current state, leaving key strategies, resume flag, and
     * human feedback intact.
     */
    public void clear() {
       this.data.clear();
    }

    /**
     * Replaces the current state's contents with the provided state.
     * <p>
     * This method effectively copies all data, key strategies, resume flag, and human
     * feedback from the provided state to this state.
     * @param overAllState the state to copy from
     */
    public void cover(OverAllState overAllState) {
       this.keyStrategies.clear();
       this.keyStrategies.putAll(overAllState.keyStrategies());
       this.data.clear();
       this.data.putAll(overAllState.data());
       this.resume = overAllState.resume;
       this.humanFeedback = overAllState.humanFeedback;
    }

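    /**
     * Feeds input into the state. A null input marks the state as resumed, and an empty
     * input changes nothing. Otherwise each key that has a registered strategy is merged
     * through that strategy, and keys without one are ignored.
     * @param input the input
     * @return this state, for chaining
     */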
    public OverAllState input(Map<String, Object> input) {
       if (input == null) {
          withResume();
          return this;
       }

       if (CollectionUtils.isEmpty(input)) {
          return this;
       }

       Map<String, KeyStrategy> keyStrategies = keyStrategies();
       input.keySet().stream().filter(key -> keyStrategies.containsKey(key)).forEach(key -> {
          this.data.put(key, keyStrategies.get(key).apply(value(key, null), input.get(key)));
       });
       return this;
    }

    /**
     * Add key and strategy over all state.
     * @param key the key
     * @param strategy the strategy
     * @return the over all state
     */
    public OverAllState registerKeyAndStrategy(String key, KeyStrategy strategy) {
       this.keyStrategies.put(key, strategy);
       return this;
    }

    /**
     * Register key and strategy over all state.
     * @param keyStrategies the key strategies
     * @return the over all state
     */
    public OverAllState registerKeyAndStrategy(Map<String, KeyStrategy> keyStrategies) {
       this.keyStrategies.putAll(keyStrategies);
       return this;
    }

    /**
     * Is contain strategy boolean.
     * @param key the key
     * @return the boolean
     */
    public boolean containStrategy(String key) {
       return this.keyStrategies.containsKey(key);
    }

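    /**
     * Applies a partial state. Each key that has a registered strategy is merged through
     * that strategy; keys without one are ignored.
     * @param partialState the partial state
     * @return an unmodifiable view of the updated data
     */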
    public Map<String, Object> updateState(Map<String, Object> partialState) {
       Map<String, KeyStrategy> keyStrategies = keyStrategies();
       partialState.keySet().stream().filter(key -> keyStrategies.containsKey(key)).forEach(key -> {
          this.data.put(key, keyStrategies.get(key).apply(value(key, null), partialState.get(key)));
       });
       return data();
    }

    /**
     * Updates the internal state based on a schema-defined strategy.
     * <p>
     * This method first validates the input state, then updates the partial state
     * according to the provided key strategies. The updated state is formed by merging
     * the original state and the modified partial state, removing any null values in the
     * process. The resulting entries are then used to update the internal data map.
     * @param state the base state to update; must not be null
     * @param partialState the partial state containing updates; may be null or empty
     * @param keyStrategies the mapping of keys to update strategies; used to transform
     * values
     */
    public void updateStateBySchema(Map<String, Object> state, Map<String, Object> partialState,
          Map<String, KeyStrategy> keyStrategies) {
       updateState(updateState(state, partialState, keyStrategies));
    }

    /**
     * Key verify boolean.
     * @return the boolean
     */
    protected boolean keyVerify() {
       return hasCommonKey(this.data, getKeyStrategies());
    }

    private Map<?, ?> getKeyStrategies() {
       return this.keyStrategies;
    }

    private boolean hasCommonKey(Map<?, ?> map1, Map<?, ?> map2) {
       Set<?> keys1 = map1.keySet();
       for (Object key : map2.keySet()) {
          if (keys1.contains(key)) {
             return true;
          }
       }
       return false;
    }

    /**
     * Updates a state with the provided partial state. The merge function is used to
     * merge the current state value with the new value.
     * @param state the current state
     * @param partialState the partial state to update from
     * @return the updated state
     * @throws NullPointerException if state is null
     */
    public static Map<String, Object> updateState(Map<String, Object> state, Map<String, Object> partialState) {
       Objects.requireNonNull(state, "state cannot be null");
       if (partialState == null || partialState.isEmpty()) {
          return state;
       }

       return Stream.concat(state.entrySet().stream(), partialState.entrySet().stream())
          .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, OverAllState::mergeFunction));
    }

    /**
     * Update state map.
     * @param state the state
     * @param partialState the partial state
     * @param keyStrategies the key strategies
     * @return the map
     */
    public static Map<String, Object> updateState(Map<String, Object> state, Map<String, Object> partialState,
          Map<String, KeyStrategy> keyStrategies) {
       Objects.requireNonNull(state, "state cannot be null");
       if (partialState == null || partialState.isEmpty()) {
          return state;
       }

       Map<String, Object> updatedPartialState = updatePartialStateFromSchema(state, partialState, keyStrategies);

       return Stream.concat(state.entrySet().stream(), updatedPartialState.entrySet().stream())
          .collect(toMapRemovingNulls(Map.Entry::getKey, Map.Entry::getValue, (currentValue, newValue) -> newValue));
    }

    /**
     * Updates the partial state from a schema using channels.
     * @param state The current state as a map of key-value pairs.
     * @param partialState The partial state to be updated.
     * @param keyStrategies A map of channel names to their implementations.
     * @return An updated version of the partial state after applying the schema and
     * channels.
     */
    private static Map<String, Object> updatePartialStateFromSchema(Map<String, Object> state,
          Map<String, Object> partialState, Map<String, KeyStrategy> keyStrategies) {
       if (keyStrategies == null || keyStrategies.isEmpty()) {
          return partialState;
       }
       return partialState.entrySet().stream().map(entry -> {

          KeyStrategy channel = keyStrategies.get(entry.getKey());
          if (channel != null) {
             Object newValue = channel.apply(state.get(entry.getKey()), entry.getValue());
             return entryOf(entry.getKey(), newValue);
          }

          return entry;
       }).collect(toMapAllowingNulls(Map.Entry::getKey, Map.Entry::getValue));
    }

    private static <T, K, U> Collector<T, ?, Map<K, U>> toMapRemovingNulls(Function<? super T, ? extends K> keyMapper,
          Function<? super T, ? extends U> valueMapper, BinaryOperator<U> mergeFunction) {
       return Collector.of(HashMap::new, (map, element) -> {
          K key = keyMapper.apply(element);
          U value = valueMapper.apply(element);
          if (value == null) {
             map.remove(key);
          }
          else {
             map.merge(key, value, mergeFunction);
          }
       }, (map1, map2) -> {
          map2.forEach((key, value) -> {
             if (value != null) {
                map1.merge(key, value, mergeFunction);
             }
          });
          return map1;
       }, Collector.Characteristics.UNORDERED);
    }

    private static <T, K, U> Collector<T, ?, Map<K, U>> toMapAllowingNulls(Function<? super T, ? extends K> keyMapper,
          Function<? super T, ? extends U> valueMapper) {
       return Collector.of(HashMap::new,
             (map, element) -> map.put(keyMapper.apply(element), valueMapper.apply(element)), (map1, map2) -> {
                map1.putAll(map2);
                return map1;
             }, Collector.Characteristics.UNORDERED);
    }

    /**
     * Merges the current value with the new value using the appropriate merge function.
     * @param currentValue the current value
     * @param newValue the new value
     * @return the merged value
     */
    private static Object mergeFunction(Object currentValue, Object newValue) {
       return newValue;
    }

    /**
     * Key strategies map.
     * @return the map
     */
    public Map<String, KeyStrategy> keyStrategies() {
       return keyStrategies;
    }

    /**
     * Data map.
     * @return the map
     */
    public final Map<String, Object> data() {
       return unmodifiableMap(data);
    }

    /**
     * Value optional.
     * @param <T> the type parameter
     * @param key the key
     * @return the optional
     */
    public final <T> Optional<T> value(String key) {
       return ofNullable((T) data().get(key));
    }

    /**
     * Value optional.
     * @param <T> the type parameter
     * @param key the key
     * @param type the type
     * @return the optional
     */
    public final <T> Optional<T> value(String key, Class<T> type) {
       if (type != null) {
          return ofNullable(type.cast(data().get(key)));
       }
       return value(key);
    }

    /**
     * Value t.
     * @param <T> the type parameter
     * @param key the key
     * @param defaultValue the default value
     * @return the t
     */
    public final <T> T value(String key, T defaultValue) {
       return (T) value(key).orElse(defaultValue);
    }

    /**
     * The type Human feedback.
     */
    public static class HumanFeedback implements Serializable {

       private Map<String, Object> data;

       private String nextNodeId;

       private String currentNodeId;

       /**
        * Instantiates a new Human feedback.
        * @param data the data
        * @param nextNodeId the next node id
        */
       public HumanFeedback(Map<String, Object> data, String nextNodeId) {
          this.data = data;
          this.nextNodeId = nextNodeId;
       }

       /**
        * Data map.
        * @return the map
        */
       public Map<String, Object> data() {
          return data;
       }

       /**
        * Next node id string.
        * @return the string
        */
       public String nextNodeId() {
          return nextNodeId;
       }

       /**
        * Sets data.
        * @param data the data
        */
       public void setData(Map<String, Object> data) {
          this.data = data;
       }

       /**
        * Sets next node id.
        * @param nextNodeId the next node id
        */
       public void setNextNodeId(String nextNodeId) {
          this.nextNodeId = nextNodeId;
       }

    }

    @Override
    public String toString() {
       return "OverAllState{" + "data=" + data + ", resume=" + resume + ", humanFeedback=" + humanFeedback
             + ", interruptMessage='" + interruptMessage + '\'' + '}';
    }

}
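
The static three-argument `updateState` in the listing merges the current state with a partial state through a collector that drops null values. The sketch below reproduces only that merge rule in a self-contained form, assuming the strategies have already been applied to the partial state. `NullRemovingMergeSketch` and `merge` are hypothetical names, and plain loops are used instead of a custom `Collector` to keep the example short.

```java
import java.util.HashMap;
import java.util.Map;

// Illustrates the "newer value wins, null removes the key" merge rule.
class NullRemovingMergeSketch {

    static Map<String, Object> merge(Map<String, Object> state, Map<String, Object> partial) {
        Map<String, Object> merged = new HashMap<>(state);
        for (Map.Entry<String, Object> entry : partial.entrySet()) {
            if (entry.getValue() == null) {
                // A null value in the partial state deletes the key.
                merged.remove(entry.getKey());
            }
            else {
                // A non-null value replaces or adds the key.
                merged.put(entry.getKey(), entry.getValue());
            }
        }
        return merged;
    }

    public static void main(String[] args) {
        Map<String, Object> state = new HashMap<>();
        state.put("draft", "v1");
        state.put("topic", "graphs");

        Map<String, Object> partial = new HashMap<>();
        partial.put("draft", null);     // removes "draft"
        partial.put("topic", "agents"); // replaces "topic"
        partial.put("score", 7);        // adds "score"

        System.out.println(merge(state, partial));
    }
}
```

A node can therefore clear a key by returning it with a null value, which is why the intermediate step in the listing collects with a null-tolerant collector first.
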
RunnableConfig

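The run configuration class.

| Field name | Field type | Description |
| --- | --- | --- |
| threadId | String | Thread ID |
| checkPointId | String | Checkpoint ID |
| nextNode | String | ID of the next node to execute |
| streamMode | CompiledGraph.StreamMode | Stream mode of the compiled graph; see the CompiledGraph class description for details |
| metadata | Map | |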
package com.alibaba.cloud.ai.graph;

import com.alibaba.cloud.ai.graph.internal.node.ParallelNode;

import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;

import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.Optional.ofNullable;

/**
 * A final class representing configuration for a runnable task. This class holds various
 * parameters such as thread ID, checkpoint ID, next node, and stream mode, providing
 * methods to modify these parameters safely without permanently altering the original
 * configuration.
 */
public final class RunnableConfig implements HasMetadata<RunnableConfig.Builder> {

    private final String threadId;

    private final String checkPointId;

    private final String nextNode;

    private final CompiledGraph.StreamMode streamMode;

    private final Map<String, Object> metadata;

    private final Map<String, Object> interruptedNodes;

    /**
     * Returns the stream mode of the compiled graph.
     * @return {@code StreamMode} representing the current stream mode.
     */
    public CompiledGraph.StreamMode streamMode() {
       return streamMode;
    }

    /**
     * Returns the thread ID as an {@link Optional}.
     * @return the thread ID wrapped in an {@code Optional}, or an empty {@code Optional}
     * if no thread ID is set.
     */
    public Optional<String> threadId() {
       return ofNullable(threadId);
    }

    /**
     * Returns the current {@code checkPointId} wrapped in an {@link Optional}.
     * @return an {@link Optional} containing the {@code checkPointId}, or
     * {@link Optional#empty()} if it is null.
     */
    public Optional<String> checkPointId() {
       return ofNullable(checkPointId);
    }

    /**
     * Returns an {@code Optional} describing the next node in the sequence, or an empty
     * {@code Optional} if there is no such node.
     * @return an {@code Optional} describing the next node, or an empty {@code Optional}
     */
    public Optional<String> nextNode() {
       return ofNullable(nextNode);
    }

    /**
     * Checks if a node is marked as interrupted in the metadata.
     * @param nodeId the ID of the node to check for interruption status
     * @return true if the node is marked as interrupted, false otherwise
     */
    public boolean isInterrupted(String nodeId) {
       return interruptData(HasMetadata.formatNodeId(nodeId)).map(value -> Boolean.TRUE.equals(value)).orElse(false);
    }

    /**
     * Marks a node as not interrupted by setting its value to false in the
     * interrupted-nodes map. The map is updated in place and nothing is returned.
     * @param nodeId the ID of the node to mark as not interrupted
     */
    public void withNodeResumed(String nodeId) {
       String formattedNodeId = HasMetadata.formatNodeId(nodeId);
       interruptedNodes.put(formattedNodeId, false);
    }

    /**
     * Removes the interrupted marker for a specific node by removing its entry from the
     * interrupted-nodes map. The map is updated in place and nothing is returned.
     * @param nodeId the ID of the node to remove the interrupted marker for
     */
    public void removeInterrupted(String nodeId) {
       String formattedNodeId = HasMetadata.formatNodeId(nodeId);
       if (interruptedNodes == null || !interruptedNodes.containsKey(formattedNodeId)) {
          return; // No change needed if the marker doesn't exist
       }
       interruptedNodes.remove(formattedNodeId);
    }

    /**
     * Marks a node as interrupted by adding it to the interrupted-nodes map with a
     * formatted key. The node ID is formatted using
     * {@link HasMetadata#formatNodeId(String)} and associated with a value of
     * {@code true}. The map is updated in place and nothing is returned.
     * @param nodeId the ID of the node to mark as interrupted; must not be null
     * @throws NullPointerException if nodeId is null
     */
    public void markNodeAsInterrupted(String nodeId) {
       interruptedNodes.put(HasMetadata.formatNodeId(nodeId), true);
    }

    /**
     * Create a new RunnableConfig with the same attributes as this one but with a
     * different {@link CompiledGraph.StreamMode}.
     * @param streamMode the new stream mode
     * @return a new RunnableConfig with the updated stream mode
     */
    public RunnableConfig withStreamMode(CompiledGraph.StreamMode streamMode) {
       if (this.streamMode == streamMode) {
          return this;
       }

       return RunnableConfig.builder(this).streamMode(streamMode).build();
    }

    /**
     * Updates the checkpoint ID of the configuration.
     * @param checkPointId The new checkpoint ID to set.
     * @return A new instance of {@code RunnableConfig} with the updated checkpoint ID, or
     * the current instance if no change is needed.
     */
    public RunnableConfig withCheckPointId(String checkPointId) {
       if (Objects.equals(this.checkPointId, checkPointId)) {
          return this;
       }
       return RunnableConfig.builder(this).checkPointId(checkPointId).build();

    }

    /**
     * Retrieves interrupt data associated with the specified key.
     * @param key the key for which to retrieve interrupt data; may be null
     * @return an Optional containing the interrupt data if present, or an empty Optional
     * if the key is null or no data is found
     */
    public Optional<Object> interruptData(String key) {
       if (key == null) {
          return Optional.empty();
       }
       return ofNullable(interruptedNodes).map(m -> m.get(key));
    }

    /**
     * return metadata value for key
     * @param key given metadata key
     * @return metadata value for key if any
     */
    @Override
    public Optional<Object> metadata(String key) {
       if (key == null) {
          return Optional.empty();
       }
       return ofNullable(metadata).map(m -> m.get(key));
    }

    /**
     * Creates a new instance of the {@link Builder} class.
     * @return A new {@code Builder} object.
     */
    public static Builder builder() {
       return new Builder();
    }

    /**
     * Creates a new {@code Builder} instance with the specified {@link RunnableConfig}.
     * @param config The configuration for the {@code Builder}.
     * @return A new {@code Builder} instance.
     */
    public static Builder builder(RunnableConfig config) {
       return new Builder(config);
    }

    /**
     * A builder pattern class for constructing {@link RunnableConfig} objects. This class
     * provides a fluent interface to set various properties of a {@link RunnableConfig}
     * object and then build the final configuration.
     */
    public static class Builder extends HasMetadata.Builder<Builder> {

       private String threadId;

       private String checkPointId;

       private String nextNode;

       private CompiledGraph.StreamMode streamMode = CompiledGraph.StreamMode.VALUES;

       /**
        * Constructs a new instance of the {@link Builder} with default configuration
        * settings. Initializes a new {@link RunnableConfig} object for configuration
        * purposes.
        */
       Builder() {
       }

       /**
        * Initializes a new instance of the {@code Builder} class with the specified
        * {@link RunnableConfig}.
        * @param config The configuration to be used for initialization.
        */
       Builder(RunnableConfig config) {
          super(requireNonNull(config, "config cannot be null!").metadata);
          this.threadId = config.threadId;
          this.checkPointId = config.checkPointId;
          this.nextNode = config.nextNode;
          this.streamMode = config.streamMode;
       }

       /**
        * Sets the ID of the thread.
        * @param threadId the ID of the thread to set
        * @return a reference to this {@code Builder} object so that method calls can be
        * chained together
        */
       public Builder threadId(String threadId) {
          this.threadId = threadId;
          return this;
       }

       /**
        * Sets the checkpoint ID for the configuration.
        * @param checkPointId the ID of the checkpoint to be set
        * @return a reference to this {@code Builder} so that method calls can be chained
        */
       public Builder checkPointId(String checkPointId) {
          this.checkPointId = checkPointId;
          return this;
       }

       /**
        * Sets the next node in the configuration and returns this builder for method
        * chaining.
        * @param nextNode The next node to be set.
        * @return This builder instance, allowing for method chaining.
        */
       public Builder nextNode(String nextNode) {
          this.nextNode = nextNode;
          return this;
       }

       /**
        * Sets the stream mode of the configuration.
        * @param streamMode The {@link CompiledGraph.StreamMode} to set.
        * @return A reference to this builder for method chaining.
        */
       public Builder streamMode(CompiledGraph.StreamMode streamMode) {
          this.streamMode = streamMode;
          return this;
       }

       /**
        * Adds a custom {@link Executor} for a specific parallel node.
        * <p>
        * This allows you to control the execution of branches within a parallel node.
        * When a parallel node is executed, it will look for an executor in the
        * {@link RunnableConfig} metadata. If found, it will be used to run the parallel
        * branches concurrently.
        * @param nodeId the ID of the parallel node.
        * @param executor the {@link Executor} to use for the parallel node.
        * @return this {@code Builder} instance for method chaining.
        */
       public Builder addParallelNodeExecutor(String nodeId, Executor executor) {
          return addMetadata(ParallelNode.formatNodeId(nodeId), requireNonNull(executor, "executor cannot be null!"));
       }

       /**
        * Constructs and returns the configured {@code RunnableConfig} object.
        * @return the configured {@code RunnableConfig} object
        */
       public RunnableConfig build() {
          return new RunnableConfig(this);
       }

    }

    /**
     * Creates a new {@code RunnableConfig} from the values held by the given builder.
     * The metadata map is copied and the interrupted-nodes map starts empty.
     * @param builder The configuration builder.
     */
    private RunnableConfig(Builder builder) {
       this.threadId = builder.threadId;
       this.checkPointId = builder.checkPointId;
       this.nextNode = builder.nextNode;
       this.streamMode = builder.streamMode;
       this.metadata = ofNullable(builder.metadata()).map(Map::copyOf).orElse(null);
       this.interruptedNodes = new ConcurrentHashMap<>();
    }

    @Override
    public String toString() {
       return format("RunnableConfig{ threadId=%s, checkPointId=%s, nextNode=%s, streamMode=%s }", threadId,
             checkPointId, nextNode, streamMode);
    }

}
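
The listing combines three ideas: final fields set once from a builder, `Optional` accessors for nullable values, and `withXxx` methods that return the same instance when nothing changes and a modified copy otherwise. The self-contained sketch below shows that pattern in isolation. `MiniConfig` and its `Builder` are hypothetical stand-ins with only two fields, not the library types.

```java
import java.util.Objects;
import java.util.Optional;

// Simplified stand-in for an immutable run configuration; not the library class.
class ConfigSketch {

    static final class MiniConfig {

        private final String threadId;

        private final String checkPointId;

        private MiniConfig(Builder builder) {
            this.threadId = builder.threadId;
            this.checkPointId = builder.checkPointId;
        }

        // Nullable fields are exposed as Optional values.
        Optional<String> threadId() {
            return Optional.ofNullable(threadId);
        }

        Optional<String> checkPointId() {
            return Optional.ofNullable(checkPointId);
        }

        // Returns this instance when nothing changes, otherwise a modified copy.
        MiniConfig withCheckPointId(String newCheckPointId) {
            if (Objects.equals(checkPointId, newCheckPointId)) {
                return this;
            }
            return new Builder(this).checkPointId(newCheckPointId).build();
        }
    }

    static final class Builder {

        private String threadId;

        private String checkPointId;

        Builder() {
        }

        // Copy constructor: start from an existing configuration.
        Builder(MiniConfig source) {
            this.threadId = source.threadId;
            this.checkPointId = source.checkPointId;
        }

        Builder threadId(String threadId) {
            this.threadId = threadId;
            return this;
        }

        Builder checkPointId(String checkPointId) {
            this.checkPointId = checkPointId;
            return this;
        }

        MiniConfig build() {
            return new MiniConfig(this);
        }
    }

    public static void main(String[] args) {
        MiniConfig base = new Builder().threadId("t-1").build();
        MiniConfig changed = base.withCheckPointId("cp-1");

        System.out.println(base.checkPointId().isPresent()); // false
        System.out.println(changed.checkPointId().get());    // cp-1
        System.out.println(base.withCheckPointId(null) == base); // true
    }
}
```
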
StateGraph

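Represents and builds state-based graph workflows. It provides the following capabilities:

  • Graph structure definition: an API for building a directed graph, including node and edge definitions
  • State management: works together with OverAllState to manage state while the graph executes
  • Workflow orchestration: supports defining complex execution logic, including conditional edges and subgraphs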
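
The three capabilities above can be sketched with a tiny, self-contained graph runner. This is an illustration under simplifying assumptions, not the library's implementation: `MiniGraph`, `addEdge`, `addConditionalEdge`, and `run` are names chosen for this sketch, nodes are synchronous functions over a plain map, and every node result simply overwrites existing keys.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Minimal state-graph runner for illustration; not the library class.
class MiniGraphSketch {

    static final String START = "START";

    static final String END = "END";

    static final class MiniGraph {

        private final Map<String, Function<Map<String, Object>, Map<String, Object>>> nodes = new HashMap<>();

        private final Map<String, Function<Map<String, Object>, String>> edges = new HashMap<>();

        MiniGraph addNode(String id, Function<Map<String, Object>, Map<String, Object>> action) {
            nodes.put(id, action);
            return this;
        }

        // A fixed edge always routes to the same target.
        MiniGraph addEdge(String sourceId, String targetId) {
            edges.put(sourceId, state -> targetId);
            return this;
        }

        // A conditional edge picks the target from the current state.
        MiniGraph addConditionalEdge(String sourceId, Function<Map<String, Object>, String> router) {
            edges.put(sourceId, router);
            return this;
        }

        Map<String, Object> run(Map<String, Object> input) {
            Map<String, Object> state = new HashMap<>(input);
            String current = next(START, state);
            while (!END.equals(current)) {
                Function<Map<String, Object>, Map<String, Object>> action = nodes.get(current);
                if (action == null) {
                    throw new IllegalStateException("Unknown node: " + current);
                }
                state.putAll(action.apply(state));
                current = next(current, state);
            }
            return state;
        }

        private String next(String sourceId, Map<String, Object> state) {
            Function<Map<String, Object>, String> router = edges.get(sourceId);
            if (router == null) {
                throw new IllegalStateException("No edge from: " + sourceId);
            }
            return router.apply(state);
        }
    }

    static MiniGraph demoGraph() {
        return new MiniGraph()
            .addNode("measure", state -> Map.of("length", ((String) state.get("input")).length()))
            .addNode("shorten", state -> Map.of("summary", ((String) state.get("input")).substring(0, 5) + "..."))
            .addEdge(START, "measure")
            .addConditionalEdge("measure", state -> (Integer) state.get("length") > 5 ? "shorten" : END)
            .addEdge("shorten", END);
    }

    public static void main(String[] args) {
        System.out.println(demoGraph().run(Map.of("input", "hello world")).get("summary")); // hello...
        System.out.println(demoGraph().run(Map.of("input", "hi")).containsKey("summary"));  // false
    }
}
```
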
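
| Field name | Field type | Description |
| --- | --- | --- |
| START | String | Identifier constant for the graph's start node ("START") |
| END | String | Identifier constant for the graph's end node ("END") |
| ERROR | String | Identifier constant for the graph's error node ("ERROR") |
| NODEBEFORE | String | Identifier constant for the hook that runs before a node executes ("NODEBEFORE") |
| NODEAFTER | String | Identifier constant for the hook that runs after a node executes ("NODEAFTER") |
| nodes | Nodes | Container for all nodes in the graph |
| edges | Edges | Container for all edges in the graph |
| overAllStateFactory | OverAllStateFactory | Factory that creates overall state instances (deprecated) |
| keyStrategyFactory | KeyStrategyFactory | Factory that provides key strategies |
| name | String | Name of the graph |
| stateSerializer | PlainTextStateSerializer | State serializer; when none is passed in, the Jackson-based inner class JacksonSerializer is used |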

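Public methods

| | Method name | Description |
| --- | --- | --- |
| Construction | StateGraph | Supports five construction modes:<br>- No arguments<br>- KeyStrategyFactory keyStrategyFactory: with a key strategy factory<br>- (String name, KeyStrategyFactory keyStrategyFactory): with a name and a key strategy factory<br>- (KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer): with a key strategy factory and a serializer<br>- (String name, KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer): with a name, a key strategy factory, and a serializer |
| Node management | addNode | - (String id, Node node): adds a node instance<br>- (String id, AsyncNodeAction action): adds a regular node<br>- (String id, AsyncNodeActionWithConfig actionWithConfig): adds a node that takes a configuration<br>- (String id, AsyncCommandAction action, Map |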

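The static inner class Nodes manages the graph's collection of nodes.

  • anyMatchById(String id): checks whether a node with the given ID exists
  • onlySubStateGraphNodes(): returns all subgraph nodes
  • exceptSubStateGraphNodes(): returns all nodes except subgraph nodes

The static inner class Edges manages the graph's collection of edges.

  • edgeBySourceId(String sourceId): finds the edge with the given source node ID
  • edgesByTargetId(String targetId): finds the edges with the given target node ID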
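
The lookups listed above can be sketched with two small containers. The method names `anyMatchById`, `edgeBySourceId`, and `edgesByTargetId` come from the lists above; everything else (the `Edge` class with two string fields, the `add` helpers, the list-backed storage) is an assumption made for this illustration and may differ from the library's internal types.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

// Illustrative node and edge containers; not the library classes.
class ContainersSketch {

    static final class Edge {

        final String sourceId;

        final String targetId;

        Edge(String sourceId, String targetId) {
            this.sourceId = sourceId;
            this.targetId = targetId;
        }
    }

    static final class Nodes {

        private final List<String> ids = new ArrayList<>();

        void add(String id) {
            ids.add(id);
        }

        boolean anyMatchById(String id) {
            return ids.stream().anyMatch(existing -> existing.equals(id));
        }
    }

    static final class Edges {

        private final List<Edge> elements = new ArrayList<>();

        void add(String sourceId, String targetId) {
            elements.add(new Edge(sourceId, targetId));
        }

        // Returns the first edge that starts at the given node, if any.
        Optional<Edge> edgeBySourceId(String sourceId) {
            return elements.stream().filter(edge -> edge.sourceId.equals(sourceId)).findFirst();
        }

        // Several edges may point at the same node, so a list is returned.
        List<Edge> edgesByTargetId(String targetId) {
            return elements.stream().filter(edge -> edge.targetId.equals(targetId)).collect(Collectors.toList());
        }
    }

    public static void main(String[] args) {
        Nodes nodes = new Nodes();
        nodes.add("a");
        nodes.add("b");

        Edges edges = new Edges();
        edges.add("START", "a");
        edges.add("a", "END");
        edges.add("b", "END");

        System.out.println(nodes.anyMatchById("a"));                  // true
        System.out.println(edges.edgeBySourceId("a").get().targetId); // END
        System.out.println(edges.edgesByTargetId("END").size());      // 2
    }
}
```
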
package com.alibaba.cloud.ai.graph;

import com.alibaba.cloud.ai.graph.action.AsyncCommandAction;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import com.alibaba.cloud.ai.graph.checkpoint.config.SaverConfig;
import com.alibaba.cloud.ai.graph.checkpoint.constant.SaverConstant;
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;

import com.alibaba.cloud.ai.graph.exception.Errors;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.internal.edge.Edge;
import com.alibaba.cloud.ai.graph.internal.edge.EdgeCondition;
import com.alibaba.cloud.ai.graph.internal.edge.EdgeValue;
import com.alibaba.cloud.ai.graph.internal.node.CommandNode;
import com.alibaba.cloud.ai.graph.internal.node.Node;
import com.alibaba.cloud.ai.graph.internal.node.SubCompiledGraphNode;
import com.alibaba.cloud.ai.graph.internal.node.SubStateGraphNode;
import com.alibaba.cloud.ai.graph.serializer.StateSerializer;
import com.alibaba.cloud.ai.graph.serializer.plaintext.PlainTextStateSerializer;
import com.alibaba.cloud.ai.graph.serializer.plaintext.jackson.JacksonStateSerializer;
import com.alibaba.cloud.ai.graph.state.AgentStateFactory;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.LinkedHashSet;

/**
 * Represents a state graph with nodes and edges.
 */
public class StateGraph {

    /**
     * Constant representing the END of the graph.
     */
    public static final String END = "END";

    /**
     * Constant representing the START of the graph.
     */
    public static final String START = "START";

    /**
     * Constant representing the ERROR of the graph.
     */
    public static final String ERROR = "ERROR";

    /**
     * Constant representing the NODEBEFORE of the graph.
     */
    public static final String NODEBEFORE = "NODEBEFORE";

    /**
     * Constant representing the NODEAFTER of the graph.
     */
    public static final String NODEAFTER = "NODEAFTER";

    /**
     * Collection of nodes in the graph.
     */
    final Nodes nodes = new Nodes();

    /**
     * Collection of edges in the graph.
     */
    final Edges edges = new Edges();

    /**
     * Factory for creating overall state instances.
     */
    private OverAllStateFactory overAllStateFactory;

    /**
     * Factory for providing key strategies.
     */
    private KeyStrategyFactory keyStrategyFactory;

    /**
     * Name of the graph.
     */
    private String name;

    /**
     * Serializer for the state.
     */
    private final PlainTextStateSerializer stateSerializer;

    /**
     * Jackson-based serializer for state.
     */
    static class JacksonSerializer extends JacksonStateSerializer {

       /**
        * Instantiates a new Jackson serializer.
        */
       public JacksonSerializer() {
          super(OverAllState::new);
       }

       /**
        * Gets object mapper.
        * @return the object mapper
        */
       ObjectMapper getObjectMapper() {
          return objectMapper;
       }

    }

    /**
     * Constructs a StateGraph with the specified name, key strategy factory, and state
     * serializer.
     * @param name the name of the graph
     * @param keyStrategyFactory the factory for providing key strategies
     * @param stateSerializer the plain text state serializer to use
     */
    public StateGraph(String name, KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer) {
       this.name = name;
       this.keyStrategyFactory = keyStrategyFactory;
       this.stateSerializer = stateSerializer;
    }

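    /**
     * Constructs an unnamed StateGraph with the specified key strategy factory and state
     * serializer.
     * @param keyStrategyFactory the factory for providing key strategies
     * @param stateSerializer the plain text state serializer to use
     */
    public StateGraph(KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer) {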
       this.keyStrategyFactory = keyStrategyFactory;
       this.stateSerializer = stateSerializer;
    }

    /**
     * Constructs a StateGraph with the given key strategy factory and name.
     * @param keyStrategyFactory the factory for providing key strategies
     * @param name the name of the graph
     */
    public StateGraph(String name, KeyStrategyFactory keyStrategyFactory) {
       this.keyStrategyFactory = keyStrategyFactory;
       this.name = name;
       this.stateSerializer = new JacksonSerializer();
    }

    /**
     * Constructs a StateGraph with the provided key strategy factory.
     * @param keyStrategyFactory the factory for providing key strategies
     */
    public StateGraph(KeyStrategyFactory keyStrategyFactory) {
       this.keyStrategyFactory = keyStrategyFactory;
       this.stateSerializer = new JacksonSerializer();
    }

    /**
     * Deprecated constructor that initializes a StateGraph with the specified name,
     * overall state factory, and state serializer.
     * @param name the name of the graph
     * @param overAllStateFactory the factory for creating overall state instances
     * @param plainTextStateSerializer the plain text state serializer to use
     */
    @Deprecated
    public StateGraph(String name, OverAllStateFactory overAllStateFactory,
          PlainTextStateSerializer plainTextStateSerializer) {
       this.name = name;
       this.overAllStateFactory = overAllStateFactory;
       this.stateSerializer = plainTextStateSerializer;
    }

    /**
     * Deprecated constructor that initializes a StateGraph with the specified name and
     * overall state factory.
     * @param name the name of the graph
     * @param overAllStateFactory the factory for creating overall state instances
     */
    @Deprecated
    public StateGraph(String name, OverAllStateFactory overAllStateFactory) {
       this.name = name;
       this.overAllStateFactory = overAllStateFactory;
       this.stateSerializer = new JacksonSerializer();
    }

    /**
     * Deprecated constructor that initializes a StateGraph with the provided overall
     * state factory.
     * @param overAllStateFactory the factory for creating overall state instances
     */
    @Deprecated
    public StateGraph(OverAllStateFactory overAllStateFactory) {
       this.overAllStateFactory = overAllStateFactory;
       this.stateSerializer = new JacksonSerializer();
    }

    /**
     * Deprecated constructor that initializes a StateGraph with the provided overall
     * state factory and state serializer.
     * @param overAllStateFactory the factory for creating overall state instances
     * @param plainTextStateSerializer the plain text state serializer to use
     */
    @Deprecated
    public StateGraph(OverAllStateFactory overAllStateFactory, PlainTextStateSerializer plainTextStateSerializer) {
       this.overAllStateFactory = overAllStateFactory;
       this.stateSerializer = plainTextStateSerializer;
    }

    /**
     * Default constructor that initializes a StateGraph with a Jackson-based state
     * serializer and an empty key strategy map.
     */
    public StateGraph() {
       this.stateSerializer = new JacksonSerializer();
       this.keyStrategyFactory = HashMap::new;
    }

    /**
     * Gets the name of the graph.
     * @return the name
     */
    public String getName() {
       return name;
    }

    /**
     * Gets the state serializer used by this graph.
     * @return the state serializer
     */
    public StateSerializer<OverAllState> getStateSerializer() {
       return stateSerializer;
    }

    /**
     * Gets the state factory associated with this graph's state serializer.
     * @return the state factory
     */
    public final AgentStateFactory<OverAllState> getStateFactory() {
       return stateSerializer.stateFactory();
    }

    /**
     * Gets the overall state factory.
     * @return the overall state factory
     */
    @Deprecated
    public final OverAllStateFactory getOverAllStateFactory() {
       return overAllStateFactory;
    }

    /**
     * Gets the key strategy factory.
     * @return the key strategy factory
     */
    public final KeyStrategyFactory getKeyStrategyFactory() {
       return keyStrategyFactory;
    }

    /**
     * Adds a commandNode to the graph.
     * @param id the identifier of the node
     * @param action AsyncCommandAction action
     * @param mappings the mappings to be used for conditional edges
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, AsyncCommandAction action, Map<String, String> mappings)
          throws GraphStateException {

       return addNode(id, new CommandNode(id, action, mappings));
    }

    /**
     * Adds a node to the graph.
     * @param id the identifier of the node
     * @param action the asynchronous node action to be performed by the node
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, AsyncNodeAction action) throws GraphStateException {
       return addNode(id, AsyncNodeActionWithConfig.of(action));
    }

    /**
     * Adds a node to the graph with the specified action and configuration.
     * @param id the identifier of the node
     * @param actionWithConfig the action to be performed by the node
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, AsyncNodeActionWithConfig actionWithConfig) throws GraphStateException {
       Node node = new Node(id, (config) -> actionWithConfig);
       return addNode(id, node);
    }

    /**
     * Adds a node to the graph with the specified identifier and node instance.
     * @param id the identifier of the node
     * @param node the node to be added
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, Node node) throws GraphStateException {
       if (Objects.equals(node.id(), END)) {
          throw Errors.invalidNodeIdentifier.exception(END);
       }
       if (!Objects.equals(node.id(), id)) {
          throw Errors.invalidNodeIdentifier.exception(node.id(), id);
       }

       if (nodes.elements.contains(node)) {
          throw Errors.duplicateNodeError.exception(id);
       }

       nodes.elements.add(node);
       return this;
    }

    /**
     * Adds a subgraph to the state graph by creating a node with the specified
     * identifier. This implies that the subgraph shares the same state with the parent
     * graph.
     * @param id the identifier of the node representing the subgraph
     * @param subGraph the compiled subgraph to be added
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, CompiledGraph subGraph) throws GraphStateException {
       if (Objects.equals(id, END)) {
          throw Errors.invalidNodeIdentifier.exception(END);
       }

       var node = new SubCompiledGraphNode(id, subGraph);

       if (nodes.elements.contains(node)) {
          throw Errors.duplicateNodeError.exception(id);
       }

       nodes.elements.add(node);
       return this;
    }

    /**
     * Adds a subgraph to the state graph by creating a node with the specified
     * identifier. This implies that the subgraph will share the same state with the
     * parent graph and will be compiled when the parent is compiled.
     * @param id the identifier of the node representing the subgraph
     * @param subGraph the subgraph to be added; it will be compiled during compilation of
     * the parent
     * @return this state graph instance
     * @throws GraphStateException if the node identifier is invalid or the node already
     * exists
     */
    public StateGraph addNode(String id, StateGraph subGraph) throws GraphStateException {
       if (Objects.equals(id, END)) {
          throw Errors.invalidNodeIdentifier.exception(END);
       }

       subGraph.validateGraph();

       var node = new SubStateGraphNode(id, subGraph);

       if (nodes.elements.contains(node)) {
          throw Errors.duplicateNodeError.exception(id);
       }

       nodes.elements.add(node);
       return this;
    }

    /**
     * Adds an edge to the graph between the specified source and target nodes.
     * @param sourceId the identifier of the source node
     * @param targetId the identifier of the target node
     * @return this state graph instance
     * @throws GraphStateException if the edge identifier is invalid or the edge already
     * exists
     */
    public StateGraph addEdge(String sourceId, String targetId) throws GraphStateException {
       if (Objects.equals(sourceId, END)) {
          throw Errors.invalidEdgeIdentifier.exception(END);
       }

       var newEdge = new Edge(sourceId, new EdgeValue(targetId));

       int index = edges.elements.indexOf(newEdge);
       if (index >= 0) {
          var newTargets = new ArrayList<>(edges.elements.get(index).targets());
          newTargets.add(newEdge.target());
          edges.elements.set(index, new Edge(sourceId, newTargets));
       }
       else {
          edges.elements.add(newEdge);
       }

       return this;
    }

    /**
     * Adds conditional edges to the graph based on the provided condition and mappings.
     * @param sourceId the identifier of the source node
     * @param condition the command action used to determine the target node
     * @param mappings the mappings of conditions to target nodes
     * @return this state graph instance
     * @throws GraphStateException if the edge identifier is invalid, the mappings are
     * empty, or the edge already exists
     */
    public StateGraph addConditionalEdges(String sourceId, AsyncCommandAction condition, Map<String, String> mappings)
          throws GraphStateException {
       if (Objects.equals(sourceId, END)) {
          throw Errors.invalidEdgeIdentifier.exception(END);
       }
       if (mappings == null || mappings.isEmpty()) {
          throw Errors.edgeMappingIsEmpty.exception(sourceId);
       }

       var newEdge = new Edge(sourceId, new EdgeValue(new EdgeCondition(condition, mappings)));

       if (edges.elements.contains(newEdge)) {
          throw Errors.duplicateConditionalEdgeError.exception(sourceId);
       }
       else {
          edges.elements.add(newEdge);
       }
       return this;
    }

    /**
     * Adds conditional edges to the graph based on the provided edge action condition and
     * mappings.
     * @param sourceId the identifier of the source node
     * @param condition the edge action used to determine the target node
     * @param mappings the mappings of conditions to target nodes
     * @return this state graph instance
     * @throws GraphStateException if the edge identifier is invalid, the mappings are
     * empty, or the edge already exists
     */
    public StateGraph addConditionalEdges(String sourceId, AsyncEdgeAction condition, Map<String, String> mappings)
          throws GraphStateException {
       return addConditionalEdges(sourceId, AsyncCommandAction.of(condition), mappings);
    }

    /**
     * Validates the structure of the graph ensuring all connections are valid.
     * @throws GraphStateException if there are errors related to the graph state
     */
    void validateGraph() throws GraphStateException {
       var edgeStart = edges.edgeBySourceId(START).orElseThrow(Errors.missingEntryPoint::exception);

       edgeStart.validate(nodes);

       validateNode(nodes);

       for (Edge edge : edges.elements) {
          edge.validate(nodes);
       }
    }

    private void validateNode(Nodes nodes) throws GraphStateException {
       List<CommandNode> commandNodeList = nodes.elements.stream().filter(node -> {
          return node instanceof CommandNode commandNode;
       }).map(node -> (CommandNode) node).toList();
       for (CommandNode commandNode : commandNodeList) {
          for (String key : commandNode.getMappings().keySet()) {
             if (!nodes.anyMatchById(key)) {
                throw Errors.missingNodeInEdgeMapping.exception(commandNode.id(), key);
             }
          }
       }
    }

    /**
     * Compiles the state graph into a compiled graph using the provided configuration.
     * @param config the compile configuration
     * @return a compiled graph
     * @throws GraphStateException if there are errors related to the graph state
     */
    public CompiledGraph compile(CompileConfig config) throws GraphStateException {
       Objects.requireNonNull(config, "config cannot be null");

       validateGraph();

       return new CompiledGraph(this, config);
    }

    /**
     * Compiles the state graph into a compiled graph using a default configuration with
     * memory saver.
     * @return a compiled graph
     * @throws GraphStateException if there are errors related to the graph state
     */
    public CompiledGraph compile() throws GraphStateException {
       SaverConfig saverConfig = SaverConfig.builder().register(SaverConstant.MEMORY, new MemorySaver()).build();
       return compile(CompileConfig.builder().saverConfig(saverConfig).build());
    }

    /**
     * Generates a drawable graph representation of the state graph.
     * @param type the type of graph representation to generate
     * @param title the title of the graph
     * @param printConditionalEdges whether to include conditional edges in the output
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) {
       String content = type.generator.generate(nodes, edges, title, printConditionalEdges);

       return new GraphRepresentation(type, content);
    }

    /**
     * Generates a drawable graph representation of the state graph with conditional edges
     * included.
     * @param type the type of graph representation to generate
     * @param title the title of the graph
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {
       String content = type.generator.generate(nodes, edges, title, true);

       return new GraphRepresentation(type, content);
    }

    /**
     * Generates a drawable graph representation of the state graph using the graph's name
     * as title.
     * @param type the type of graph representation to generate
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type) {
       String content = type.generator.generate(nodes, edges, name, true);

       return new GraphRepresentation(type, content);
    }

    /**
     * Container for nodes in the graph.
     */
    public static class Nodes {

       /**
        * The collection of nodes.
        */
       public final Set<Node> elements;

       /**
        * Instantiates a new Nodes container with the provided elements.
        * @param elements the elements to initialize
        */
       public Nodes(Collection<Node> elements) {
          this.elements = new LinkedHashSet<>(elements);
       }

       /**
        * Instantiates a new empty Nodes container.
        */
       public Nodes() {
          this.elements = new LinkedHashSet<>();
       }

       /**
        * Checks if any node matches the given identifier.
        * @param id the identifier to match
        * @return true if a matching node is found, false otherwise
        */
       public boolean anyMatchById(String id) {
          return elements.stream().anyMatch(n -> Objects.equals(n.id(), id));
       }

       /**
        * Returns a list of sub-state graph nodes.
        * @return a list of sub-state graph nodes
        */
       public List<SubStateGraphNode> onlySubStateGraphNodes() {
          return elements.stream()
             .filter(n -> n instanceof SubStateGraphNode)
             .map(n -> (SubStateGraphNode) n)
             .toList();
       }

       /**
        * Returns a list of nodes excluding sub-state graph nodes.
        * @return a list of nodes excluding sub-state graph nodes
        */
       public List<Node> exceptSubStateGraphNodes() {
          return elements.stream().filter(n -> !(n instanceof SubStateGraphNode)).toList();
       }

    }

    /**
     * Container for edges in the graph.
     */
    public static class Edges {

       /**
        * The collection of edges.
        */
       public final List<Edge> elements;

       /**
        * Instantiates a new Edges container with the provided elements.
        * @param elements the elements to initialize
        */
       public Edges(Collection<Edge> elements) {
          this.elements = new LinkedList<>(elements);
       }

       /**
        * Instantiates a new empty Edges container.
        */
       public Edges() {
          this.elements = new LinkedList<>();
       }

       /**
        * Retrieves the first edge matching the specified source identifier.
        * @param sourceId the source identifier to match
        * @return an optional containing the matched edge, or empty if none found
        */
       public Optional<Edge> edgeBySourceId(String sourceId) {
          return elements.stream().filter(e -> Objects.equals(e.sourceId(), sourceId)).findFirst();
       }

       /**
        * Retrieves a list of edges targeting the specified node identifier.
        * @param targetId the target identifier to match
        * @return a list of edges targeting the specified identifier
        */
       public List<Edge> edgesByTargetId(String targetId) {
          return elements.stream().filter(e -> e.anyMatchByTargetId(targetId)).toList();
       }

    }

}
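One detail of `addEdge` in the listing above is worth isolating: a second call with the same source does not fail. The new target is appended to the targets of the existing edge, which is how a single source fans out to several targets. The following is a minimal, self-contained sketch of that merge step, not the library code. `MiniEdge`, `MiniEdges` and `EdgeMergeSketch` are hypothetical names, targets are simplified to plain strings, and the sketch assumes that edge equality depends only on the source id, which is what the `indexOf` lookup in `addEdge` appears to rely on.

```java
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;

// Hypothetical stand-in for the library's Edge. Assumption: two edges are
// equal when they share a source id, regardless of their targets.
final class MiniEdge {
    final String sourceId;
    final List<String> targets;

    MiniEdge(String sourceId, List<String> targets) {
        this.sourceId = sourceId;
        this.targets = List.copyOf(targets);
    }

    @Override
    public boolean equals(Object o) {
        return o instanceof MiniEdge && Objects.equals(sourceId, ((MiniEdge) o).sourceId);
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(sourceId);
    }
}

// Hypothetical stand-in for StateGraph.Edges plus the addEdge merge logic.
final class MiniEdges {
    final List<MiniEdge> elements = new LinkedList<>();

    MiniEdges addEdge(String sourceId, String targetId) {
        if (Objects.equals(sourceId, "END")) {
            throw new IllegalArgumentException("END cannot be the source of an edge");
        }
        MiniEdge newEdge = new MiniEdge(sourceId, List.of(targetId));
        int index = elements.indexOf(newEdge);
        if (index >= 0) {
            // Same source already registered: append the target to the existing edge.
            List<String> merged = new ArrayList<>(elements.get(index).targets);
            merged.add(targetId);
            elements.set(index, new MiniEdge(sourceId, merged));
        } else {
            elements.add(newEdge);
        }
        return this;
    }
}

final class EdgeMergeSketch {
    public static void main(String[] args) {
        MiniEdges edges = new MiniEdges()
            .addEdge("START", "a")
            .addEdge("START", "b")
            .addEdge("a", "END");
        for (MiniEdge e : edges.elements) {
            System.out.println(e.sourceId + " -> " + e.targets);
        }
    }
}
```

In this sketch the two calls with source `START` end up as one edge holding both targets, while the conditional-edge path in the real listing takes the opposite approach and rejects a duplicate source with an error.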
CompileConfig

CompileConfig configures how a graph is compiled, including checkpoint savers, interrupt points, and lifecycle listeners.

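| Field | Type | Description |
| --- | --- | --- |
| saverConfig | `SaverConfig` | Saver configuration that manages checkpoint savers; replaces the old `checkpointSaver` field |
| lifecycleListeners | `Deque<GraphLifecycleListener>` | Queue of graph lifecycle listeners that observe node execution events |
| interruptsBefore | `Set<String>` | IDs of the nodes before which execution is interrupted |
| interruptsAfter | `Set<String>` | IDs of the nodes after which execution is interrupted |
| releaseThread | `boolean` | Thread release flag; indicates whether the thread is released during execution |
| observationRegistry | `ObservationRegistry` | Observation registry used for monitoring and tracing |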

Exposed methods

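| Method | Description |
| --- | --- |
| releaseThread | Returns the current value of the thread release flag |
| lifecycleListeners | Returns the queue of graph lifecycle listeners |
| observationRegistry | Returns the observation registry used for monitoring and tracing |
| interruptsBefore | Returns the set of node IDs before which execution is interrupted |
| interruptsAfter | Returns the set of node IDs after which execution is interrupted |
| checkpointSaver | Retrieves the default checkpoint saver from the saver configuration |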
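The two interrupt collections are stored as immutable sets. In the builder shown in the source that follows, the varargs setters build the set with `Set.of(...)` and the collection setters use `Collectors.toUnmodifiableSet()`, and each call replaces the previous value instead of adding to it. The two JDK factories treat duplicate node IDs differently, which the sketch below demonstrates using only the JDK. `InterruptSetSketch` and its two helper methods are hypothetical names that imitate the two setter styles; they are not part of the library.

```java
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

final class InterruptSetSketch {

    // Imitates the varargs setters: Set.of rejects duplicate and null elements.
    static Set<String> fromVarargs(String... nodeIds) {
        return Set.of(nodeIds);
    }

    // Imitates the collection setters: duplicates are collapsed, nulls are rejected.
    static Set<String> fromCollection(Collection<String> nodeIds) {
        return nodeIds.stream().collect(Collectors.toUnmodifiableSet());
    }

    public static void main(String[] args) {
        // The duplicate is silently collapsed into a single entry.
        System.out.println(fromCollection(List.of("review", "review")).size());
        try {
            fromVarargs("review", "review");
        } catch (IllegalArgumentException e) {
            System.out.println("duplicate rejected by Set.of");
        }
    }
}
```

If node IDs come from user input or are assembled from several lists, the collection-based setters are therefore the more forgiving choice.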
package com.alibaba.cloud.ai.graph;

import com.alibaba.cloud.ai.graph.checkpoint.BaseCheckpointSaver;
import com.alibaba.cloud.ai.graph.checkpoint.config.SaverConfig;
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;
import io.micrometer.observation.ObservationRegistry;

import java.util.Collection;
import java.util.Deque;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.stream.Collectors;

import static com.alibaba.cloud.ai.graph.checkpoint.constant.SaverConstant.MEMORY;
import static java.util.Optional.ofNullable;

/**
 * This class is a configuration container for defining compile settings and behaviors. It
 * includes various fields and methods to manage checkpoint savers and interrupts,
 * providing both deprecated and current accessors.
 */
public class CompileConfig {

    private SaverConfig saverConfig = new SaverConfig().register(MEMORY, new MemorySaver());

    private Deque<GraphLifecycleListener> lifecycleListeners = new LinkedBlockingDeque<>(25);

    // private BaseCheckpointSaver checkpointSaver; // replaced with SaverConfig
    private Set<String> interruptsBefore = Set.of();

    private Set<String> interruptsAfter = Set.of();

    private boolean releaseThread = false;

    private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

    /**
     * Returns the current state of the thread release flag.
     *
     * @see BaseCheckpointSaver#release(RunnableConfig)
     * @return true if the thread has been released, false otherwise
     */
    public boolean releaseThread() {
       return releaseThread;
    }

    /**
     * Gets the queue of graph lifecycle listeners. The internal deque is returned
     * directly rather than as an unmodifiable copy.
     * @return The list of lifecycle listeners.
     */
    public Queue<GraphLifecycleListener> lifecycleListeners() {
       return lifecycleListeners;
    }

    /**
     * Gets observation registry for monitoring and tracing.
     * @return The observation registry instance.
     */
    public ObservationRegistry observationRegistry() {
       return observationRegistry;
    }

    /**
     * Returns the array of interrupts that will occur before the specified node
     * (deprecated).
     * @return An array of interruptible nodes.
     * @deprecated Use {@link #interruptsBefore()} instead for better immutability and
     * type safety.
     */
    @Deprecated
    public String[] getInterruptBefore() {
       return interruptsBefore.toArray(new String[0]);
    }

    /**
     * Returns the array of interrupts that will occur after the specified node
     * (deprecated).
     * @return An array of interruptible nodes.
     * @deprecated Use {@link #interruptsAfter()} instead for better immutability and type
     * safety.
     */
    @Deprecated
    public String[] getInterruptAfter() {
       return interruptsAfter.toArray(new String[0]);
    }

    /**
     * Returns the set of interrupts that will occur before the specified node.
     * @return An unmodifiable set of interruptible nodes.
     */
    public Set<String> interruptsBefore() {
       return interruptsBefore;
    }

    /**
     * Returns the set of interrupts that will occur after the specified node.
     * @return An unmodifiable set of interruptible nodes.
     */
    public Set<String> interruptsAfter() {
       return interruptsAfter;
    }

    /**
     * Retrieves a checkpoint saver based on the specified type from the saver
     * configuration.
     * @param type The type of the checkpoint saver to retrieve.
     * @return An Optional containing the checkpoint saver if available; otherwise, empty.
     */
    public Optional<BaseCheckpointSaver> checkpointSaver(String type) {
       return ofNullable(saverConfig.get(type));
    }

    /**
     * Retrieves the default checkpoint saver from the saver configuration.
     * @return An Optional containing the default checkpoint saver if available;
     * otherwise, empty.
     */
    public Optional<BaseCheckpointSaver> checkpointSaver() {
       return ofNullable(saverConfig.get());
    }

    /**
     * Returns a new instance of the builder with default configuration settings.
     * @return A new Builder instance.
     */
    public static Builder builder() {
       return new Builder(new CompileConfig());
    }

    /**
     * Returns a new instance of the builder initialized with the provided configuration.
     * @param config The compile configuration to use as a base.
     * @return A new Builder instance initialized with the given configuration.
     */
    public static Builder builder(CompileConfig config) {
       return new Builder(config);
    }

    /**
     * Builder class for creating instances of CompileConfig. It allows setting various
     * options such as savers, interrupts, and lifecycle listeners in a fluent manner.
     */
    public static class Builder {

       private final CompileConfig config;

       /**
        * Initializes the builder with the provided compile configuration.
        * @param config The base configuration to start from.
        */
       protected Builder(CompileConfig config) {
          this.config = new CompileConfig(config);
       }

       /**
        * Sets whether the thread should be released during execution.
        * @param releaseThread Flag indicating whether to release the thread.
        * @see BaseCheckpointSaver#release(RunnableConfig)
        * @return This builder instance for method chaining.
        */
       public Builder releaseThread(boolean releaseThread) {
          this.config.releaseThread = releaseThread;
          return this;
       }

       /**
        * Sets the observation registry for monitoring and tracing.
        * @param observationRegistry The ObservationRegistry to use.
        * @return This builder instance for method chaining.
        */
       public Builder observationRegistry(ObservationRegistry observationRegistry) {
          this.config.observationRegistry = observationRegistry;
          return this;
       }

       /**
        * Sets the saver configuration for checkpoints.
        * @param saverConfig The SaverConfig to use.
        * @return This builder instance for method chaining.
        */
       public Builder saverConfig(SaverConfig saverConfig) {
          this.config.saverConfig = saverConfig;
          return this;
       }

       /**
        * Sets individual interrupt points that trigger before node execution using
        * varargs.
        * @param interruptBefore One or more strings representing interrupt points.
        * @return This builder instance for method chaining.
        */
       public Builder interruptBefore(String... interruptBefore) {
          this.config.interruptsBefore = Set.of(interruptBefore);
          return this;
       }

       /**
        * Sets individual interrupt points that trigger after node execution using
        * varargs.
        * @param interruptAfter One or more strings representing interrupt points.
        * @return This builder instance for method chaining.
        */
       public Builder interruptAfter(String... interruptAfter) {
          this.config.interruptsAfter = Set.of(interruptAfter);
          return this;
       }

       /**
        * Sets multiple interrupt points that trigger before node execution from a
        * collection.
        * @param interruptsBefore Collection of strings representing interrupt points.
        * @return This builder instance for method chaining.
        */
       public Builder interruptsBefore(Collection<String> interruptsBefore) {
          this.config.interruptsBefore = interruptsBefore.stream().collect(Collectors.toUnmodifiableSet());
          return this;
       }

       /**
        * Sets multiple interrupt points that trigger after node execution from a
        * collection.
        * @param interruptsAfter Collection of strings representing interrupt points.
        * @return This builder instance for method chaining.
        */
       public Builder interruptsAfter(Collection<String> interruptsAfter) {
          this.config.interruptsAfter = interruptsAfter.stream().collect(Collectors.toUnmodifiableSet());
          return this;
       }

       /**
        * Adds a lifecycle listener to monitor node execution events.
        * @param listener The GraphLifecycleListener to add.
        * @return This builder instance for method chaining.
        */
       public Builder withLifecycleListener(GraphLifecycleListener listener) {
          this.config.lifecycleListeners.offer(listener);
          return this;
       }

       /**
        * Finalizes the configuration and returns the compiled instance.
        * @return The configured CompileConfig object.
        */
       public CompileConfig build() {
          return config;
       }

    }

    /**
     * Default constructor used internally to create a new configuration with default
     * settings. Made private to ensure all instances are created through the builder
     * pattern.
     */
    private CompileConfig() {
    }

    /**
     * Copy constructor to create a new instance based on an existing configuration.
     * @param config The configuration to copy.
     */
    private CompileConfig(CompileConfig config) {
       this.saverConfig = config.saverConfig;
       this.interruptsBefore = config.interruptsBefore;
       this.interruptsAfter = config.interruptsAfter;
       this.releaseThread = config.releaseThread;
       this.lifecycleListeners = config.lifecycleListeners;
       this.observationRegistry = config.observationRegistry;
    }

}
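Two details of the listing above are easy to miss. The listener deque is a `LinkedBlockingDeque` bounded at 25 entries, and `withLifecycleListener` adds through `offer`, which returns `false` instead of throwing when the deque is full. The copy constructor also assigns the same deque reference rather than copying it, so a config derived through `builder(config)` shares its listeners with the base config. The sketch below reproduces both effects with the JDK only. `MiniConfig` and `ListenerQueueSketch` are hypothetical names, and listeners are modeled as plain strings.

```java
import java.util.Deque;
import java.util.concurrent.LinkedBlockingDeque;

// Hypothetical stand-in for CompileConfig, reduced to the listener deque.
final class MiniConfig {
    Deque<String> listeners = new LinkedBlockingDeque<>(25);

    MiniConfig() {
    }

    // Same shape as the copy constructor in the listing: the reference is shared.
    MiniConfig(MiniConfig base) {
        this.listeners = base.listeners;
    }
}

final class ListenerQueueSketch {

    // Offers `count` listeners and returns how many were actually accepted.
    static int register(MiniConfig config, int count) {
        int accepted = 0;
        for (int i = 0; i < count; i++) {
            if (config.listeners.offer("listener-" + i)) {
                accepted++;
            }
        }
        return accepted;
    }

    public static void main(String[] args) {
        System.out.println("accepted: " + register(new MiniConfig(), 30));

        MiniConfig base = new MiniConfig();
        MiniConfig derived = new MiniConfig(base);
        derived.listeners.offer("audit");
        System.out.println("base sees: " + base.listeners.size());
    }
}
```

Registering more listeners than the bound therefore drops the extras without any error, and checking the return value of `offer` is the only way to notice.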
CompiledGraph

A core component of the graph computation framework, representing a compiled graph structure.

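| Field | Type | Description |
| --- | --- | --- |
| stateGraph | `StateGraph` | The original state graph this compiled graph was built from |
| keyStrategyMap | `Map` | |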

Exposed methods

Method
Description
Graph execution methods
stream
- No arguments: creates the default streaming output
- (Map

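StreamMode enum

  • VALUES: value stream mode
  • SNAPSHOTS: snapshot stream mode

AsyncNodeGenerator inner class: drives asynchronous graph execution and handles streaming output

Core fields:

  • Cursor cursor: cursor object that tracks the current and next node IDs

    • String currentNodeId: ID of the current node
    • String nextNodeId: ID of the next node
    • String resumeFrom: ID of the node from which execution resumes
  • int iteration: current iteration count

  • RunnableConfig config: runtime configuration

  • boolean returnFromEmbed: marks whether execution is returning from an embedded generator

  • Map<String, Object> currentState: the current state as a Map

  • OverAllState overAllState: the wrapping OverAllState object, which offers richer state operations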

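| Core method | Description |
| --- | --- |
| next | Executes the next step of the graph |
| evaluateAction | Evaluates and executes a node action |
| nextNodeId | Determines the next node from the current node and state |
| getEmbedGenerator | Extracts an embedded generator from the partial state |
| processGeneratorOutput | Processes the output data of a generator |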
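To make the roles of the cursor, the iteration counter and the next-node lookup more concrete, here is a heavily simplified, synchronous sketch of a cursor-driven walk. It is not the `AsyncNodeGenerator` implementation: there is no state merging, checkpointing, interruption or streaming. `CursorWalkSketch`, `walk` and the plain `Map` of fixed edges are hypothetical, and the iteration cap is only modeled on the `maxIterations = 25` default visible in the source that follows.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

final class CursorWalkSketch {

    static final String START = "START";
    static final String END = "END";

    // Walks fixed edges from START to END and returns the visited node IDs.
    // `edges` maps a node ID to the ID of its single successor (an assumption;
    // the real graph also supports conditional and parallel edges).
    static List<String> walk(Map<String, String> edges, int maxIterations) {
        List<String> visited = new ArrayList<>();
        String currentNodeId = START;
        String nextNodeId = edges.get(START);
        int iteration = 0;

        while (!END.equals(nextNodeId)) {
            if (nextNodeId == null) {
                throw new IllegalStateException("no edge leaving node " + currentNodeId);
            }
            if (++iteration > maxIterations) {
                throw new IllegalStateException("max iterations reached: " + maxIterations);
            }
            // Advance the cursor, "execute" the node, then look up its successor.
            currentNodeId = nextNodeId;
            visited.add(currentNodeId);
            nextNodeId = edges.get(currentNodeId);
        }
        return visited;
    }

    public static void main(String[] args) {
        Map<String, String> edges = Map.of(START, "plan", "plan", "act", "act", END);
        System.out.println(walk(edges, 25));
    }
}
```

The iteration cap is what stops a cyclic graph from running forever in this sketch: a node whose successor is itself fails with an exception once the cap is exceeded.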
package com.alibaba.cloud.ai.graph;

import com.alibaba.cloud.ai.graph.action.AsyncCommandAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import com.alibaba.cloud.ai.graph.action.Command;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.alibaba.cloud.ai.graph.checkpoint.BaseCheckpointSaver;
import com.alibaba.cloud.ai.graph.checkpoint.Checkpoint;
import com.alibaba.cloud.ai.graph.exception.Errors;
import com.alibaba.cloud.ai.graph.exception.GraphRunnerException;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.exception.RunnableErrors;
import com.alibaba.cloud.ai.graph.internal.edge.Edge;
import com.alibaba.cloud.ai.graph.internal.edge.EdgeValue;
import com.alibaba.cloud.ai.graph.internal.node.CommandNode;
import com.alibaba.cloud.ai.graph.internal.node.ParallelNode;
import com.alibaba.cloud.ai.graph.state.StateSnapshot;
import com.alibaba.cloud.ai.graph.streaming.AsyncGeneratorUtils;
import com.alibaba.cloud.ai.graph.utils.LifeListenerUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

import java.io.IOException;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.alibaba.cloud.ai.graph.StateGraph.*;
import static java.lang.String.format;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.failedFuture;
import static java.util.stream.Collectors.toList;

/**
 * The type Compiled graph.
 */
public class CompiledGraph {

    private static final Logger log = LoggerFactory.getLogger(CompiledGraph.class);

    /**
     * The enum Stream mode.
     */
    public enum StreamMode {

       /**
        * Values stream mode.
        */
       VALUES,
       /**
        * Snapshots stream mode.
        */
       SNAPSHOTS

    }

    /**
     * The State graph.
     */
    public final StateGraph stateGraph;

    private final Map<String, KeyStrategy> keyStrategyMap;

    /**
     * The Nodes.
     */
    final Map<String, AsyncNodeActionWithConfig> nodes = new LinkedHashMap<>();

    /**
     * The Edges.
     */
    final Map<String, EdgeValue> edges = new LinkedHashMap<>();

    private final ProcessedNodesEdgesAndConfig processedData;

    private int maxIterations = 25;

    /**
     * The Compile config.
     */
    public final CompileConfig compileConfig;

    private static String INTERRUPTAFTER = "INTERRUPTED";

    /**
     * Constructs a CompiledGraph with the given StateGraph.
     * @param stateGraph the StateGraph to be used in this CompiledGraph
     * @param compileConfig the compile config
     * @throws GraphStateException the graph state exception
     */
    protected CompiledGraph(StateGraph stateGraph, CompileConfig compileConfig) throws GraphStateException {
       this.stateGraph = stateGraph;
       this.keyStrategyMap = Objects.isNull(stateGraph.getOverAllStateFactory())
             ? stateGraph.getKeyStrategyFactory()
                .apply()
                .entrySet()
                .stream()
                .map(e -> Map.entry(e.getKey(), e.getValue()))
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))
             : stateGraph.getOverAllStateFactory().create().keyStrategies();

       this.processedData = ProcessedNodesEdgesAndConfig.process(stateGraph, compileConfig);

       // CHECK INTERRUPTIONS
       for (String interruption : processedData.interruptsBefore()) {
          if (!processedData.nodes().anyMatchById(interruption)) {
             throw Errors.interruptionNodeNotExist.exception(interruption);
          }
       }
       for (String interruption : processedData.interruptsAfter()) {
          if (!processedData.nodes().anyMatchById(interruption)) {
             throw Errors.interruptionNodeNotExist.exception(interruption);
          }
       }

       // RE-CREATE THE EVENTUALLY UPDATED COMPILE CONFIG
       this.compileConfig = CompileConfig.builder(compileConfig)
          .interruptsBefore(processedData.interruptsBefore())
          .interruptsAfter(processedData.interruptsAfter())
          .build();

       // EVALUATES NODES
       for (var n : processedData.nodes().elements) {
          var factory = n.actionFactory();
          Objects.requireNonNull(factory, format("action factory for node id '%s' is null!", n.id()));
          nodes.put(n.id(), factory.apply(compileConfig));
       }

       // EVALUATE EDGES
       for (var e : processedData.edges().elements) {
          var targets = e.targets();
          if (targets.size() == 1) {
             edges.put(e.sourceId(), targets.get(0));
          }
          else {
             Supplier<Stream<EdgeValue>> parallelNodeStream = () -> targets.stream()
                .filter(target -> nodes.containsKey(target.id()));

             var parallelNodeEdges = parallelNodeStream.get()
                .map(target -> new Edge(target.id()))
                .filter(ee -> processedData.edges().elements.contains(ee))
                .map(ee -> processedData.edges().elements.indexOf(ee))
                .map(index -> processedData.edges().elements.get(index))
                .toList();

             var parallelNodeTargets = parallelNodeEdges.stream()
                .map(ee -> ee.target().id())
                .collect(Collectors.toSet());

             if (parallelNodeTargets.size() > 1) {

                var conditionalEdges = parallelNodeEdges.stream()
                   .filter(ee -> ee.target().value() != null)
                   .toList();
                if (!conditionalEdges.isEmpty()) {
                   throw Errors.unsupportedConditionalEdgeOnParallelNode.exception(e.sourceId(),
                         conditionalEdges.stream().map(Edge::sourceId).toList());
                }
                throw Errors.illegalMultipleTargetsOnParallelNode.exception(e.sourceId(), parallelNodeTargets);
             }

             var actions = parallelNodeStream.get()
                // .map( target -> nodes.remove(target.id()) )
                .map(target -> nodes.get(target.id()))
                .toList();

             var parallelNode = new ParallelNode(e.sourceId(), actions, keyStrategyMap, compileConfig);

             nodes.put(parallelNode.id(), parallelNode.actionFactory().apply(compileConfig));

             edges.put(e.sourceId(), new EdgeValue(parallelNode.id()));

             edges.put(parallelNode.id(), new EdgeValue(parallelNodeTargets.iterator().next()));

          }

       }
    }

    public Collection<StateSnapshot> getStateHistory(RunnableConfig config) {
       BaseCheckpointSaver saver = compileConfig.checkpointSaver()
          .orElseThrow(() -> (new IllegalStateException("Missing CheckpointSaver!")));

       return saver.list(config)
          .stream()
          .map(checkpoint -> StateSnapshot.of(keyStrategyMap, checkpoint, config, stateGraph.getStateFactory()))
          .collect(toList());
    }

    /**
     * Same as {@link #stateOf(RunnableConfig)} but throws an IllegalStateException if
     * checkpoint is not found.
     * @param config the RunnableConfig
     * @return the StateSnapshot of the given RunnableConfig
     * @throws IllegalStateException if the saver is not defined, or no checkpoint is
     * found
     */
    public StateSnapshot getState(RunnableConfig config) {
       return stateOf(config).orElseThrow(() -> (new IllegalStateException("Missing Checkpoint!")));
    }

    /**
     * Get the StateSnapshot of the given RunnableConfig.
     * @param config the RunnableConfig
     * @return an Optional of StateSnapshot of the given RunnableConfig
     * @throws IllegalStateException if the saver is not defined
     */
    public Optional<StateSnapshot> stateOf(RunnableConfig config) {
       BaseCheckpointSaver saver = compileConfig.checkpointSaver()
          .orElseThrow(() -> (new IllegalStateException("Missing CheckpointSaver!")));

       return saver.get(config)
          .map(checkpoint -> StateSnapshot.of(keyStrategyMap, checkpoint, config, stateGraph.getStateFactory()));

    }

    /**
     * Update the state of the graph with the given values. If asNode is given, it will be
     * used to determine the next node to run. If not given, the next node will be
     * determined by the state graph.
     * @param config the RunnableConfig containing the graph state
     * @param values the values to be updated
     * @param asNode the node id to be used for the next node. can be null
     * @return the updated RunnableConfig
     * @throws Exception when something goes wrong
     */
    public RunnableConfig updateState(RunnableConfig config, Map<String, Object> values, String asNode)
          throws Exception {
       BaseCheckpointSaver saver = compileConfig.checkpointSaver()
          .orElseThrow(() -> (new IllegalStateException("Missing CheckpointSaver!")));

       // merge values with checkpoint values
       Checkpoint branchCheckpoint = saver.get(config)
          .map(Checkpoint::new)
          .map(cp -> cp.updateState(values, keyStrategyMap))
          .orElseThrow(() -> (new IllegalStateException("Missing Checkpoint!")));

       String nextNodeId = null;
       if (asNode != null) {
          var nextNodeCommand = nextNodeId(asNode, branchCheckpoint.getState(), config);

          nextNodeId = nextNodeCommand.gotoNode();
          branchCheckpoint = branchCheckpoint.updateState(nextNodeCommand.update(), keyStrategyMap);

       }
       // update checkpoint in saver
       RunnableConfig newConfig = saver.put(config, branchCheckpoint);

       return RunnableConfig.builder(newConfig).checkPointId(branchCheckpoint.getId()).nextNode(nextNodeId).build();
    }

    /***
     * Update the state of the graph with the given values.
     * @param config the RunnableConfig containing the graph state
     * @param values the values to be updated
     * @return the updated RunnableConfig
     * @throws Exception when something goes wrong
     */
    public RunnableConfig updateState(RunnableConfig config, Map<String, Object> values) throws Exception {
       return updateState(config, values, null);
    }

    /**
     * Sets the maximum number of iterations for the graph execution.
     * @param maxIterations the maximum number of iterations
     * @throws IllegalArgumentException if maxIterations is less than or equal to 0
     */
    public void setMaxIterations(int maxIterations) {
       if (maxIterations <= 0) {
          throw new IllegalArgumentException("maxIterations must be > 0!");
       }
       this.maxIterations = maxIterations;
    }

    private Command nextNodeId(EdgeValue route, Map<String, Object> state, String nodeId, RunnableConfig config)
          throws Exception {

       if (route == null) {
          throw RunnableErrors.missingEdge.exception(nodeId);
       }
       if (route.id() != null) {
          return new Command(route.id(), state);
       }
       if (route.value() != null) {
          OverAllState derefState = stateGraph.getStateFactory().apply(state);

          var command = route.value().action().apply(derefState, config).get();

          var newRoute = command.gotoNode();

          String result = route.value().mappings().get(newRoute);
          if (result == null) {
             throw RunnableErrors.missingNodeInEdgeMapping.exception(nodeId, newRoute);
          }

          var currentState = OverAllState.updateState(state, command.update(), keyStrategyMap);

          return new Command(result, currentState);
       }
       throw RunnableErrors.executionError.exception(format("invalid edge value for nodeId: [%s] !", nodeId));
    }

    /**
     * Determines the next node ID based on the current node ID and state.
     * @param nodeId the current node ID
     * @param state the current state
     * @return the next node command
     * @throws Exception if there is an error determining the next node ID
     */
    private Command nextNodeId(String nodeId, Map<String, Object> state, RunnableConfig config) throws Exception {
       return nextNodeId(edges.get(nodeId), state, nodeId, config);

    }

    private Command getEntryPoint(Map<String, Object> state, RunnableConfig config) throws Exception {
       var entryPoint = this.edges.get(START);
       return nextNodeId(entryPoint, state, "entryPoint", config);
    }

    private boolean shouldInterruptBefore(String nodeId, String previousNodeId) {
       if (previousNodeId == null) { // FIX RESUME ERROR
          return false;
       }
       return compileConfig.interruptsBefore().contains(nodeId);
    }

    private boolean shouldInterruptAfter(String nodeId, String previousNodeId) {
       if (nodeId == null || Objects.equals(nodeId, previousNodeId)) { // FIX RESUME
                                                       // ERROR
          return false;
       }
       return (compileConfig.interruptBeforeEdge() && Objects.equals(nodeId, INTERRUPTAFTER))
             || compileConfig.interruptsAfter().contains(nodeId);
    }

    private Optional<Checkpoint> addCheckpoint(RunnableConfig config, String nodeId, Map<String, Object> state,
          String nextNodeId) throws Exception {
       if (compileConfig.checkpointSaver().isPresent()) {
          var cp = Checkpoint.builder().nodeId(nodeId).state(cloneState(state)).nextNodeId(nextNodeId).build();
          compileConfig.checkpointSaver().get().put(config, cp);
          return Optional.of(cp);
       }
       return Optional.empty();

    }

    /**
     * Gets initial state.
     * @param inputs the inputs
     * @param config the config
     * @return the initial state
     */
    Map<String, Object> getInitialState(Map<String, Object> inputs, RunnableConfig config) {

       return compileConfig.checkpointSaver()
          .flatMap(saver -> saver.get(config))
          .map(cp -> OverAllState.updateState(cp.getState(), inputs, keyStrategyMap))
          .orElseGet(() -> OverAllState.updateState(new HashMap<>(), inputs, keyStrategyMap));
    }

    /**
     * Clone state over all state.
     * @param data the data
     * @return the over all state
     */
    OverAllState cloneState(Map<String, Object> data) throws IOException, ClassNotFoundException {
       return stateGraph.getStateSerializer().cloneObject(data);
    }

    /**
     * Creates an AsyncGenerator stream of NodeOutput based on the provided inputs.
     * @param inputs the input map
     * @param config the invoke configuration
     * @return an AsyncGenerator stream of NodeOutput
     */
    public AsyncGenerator<NodeOutput> stream(Map<String, Object> inputs, RunnableConfig config)
          throws GraphRunnerException {
       Objects.requireNonNull(config, "config cannot be null");
       final AsyncNodeGenerator<NodeOutput> generator = new AsyncNodeGenerator<>(stateCreate(inputs), config);

       return new AsyncGenerator.WithEmbed<>(generator);
    }

    /**
     * Stream async generator.
     * @param overAllState the over all state
     * @param config the config
     * @return the async generator
     */
    public AsyncGenerator<NodeOutput> streamFromInitialNode(OverAllState overAllState, RunnableConfig config)
          throws GraphRunnerException {
       Objects.requireNonNull(config, "config cannot be null");
       final AsyncNodeGenerator<NodeOutput> generator = new AsyncNodeGenerator<>(overAllState, config);

       return new AsyncGenerator.WithEmbed<>(generator);
    }

    /**
     * Creates an AsyncGenerator stream of NodeOutput based on the provided inputs.
     * @param inputs the input map
     * @return an AsyncGenerator stream of NodeOutput
     */
    public AsyncGenerator<NodeOutput> stream(Map<String, Object> inputs) throws GraphRunnerException {
       return this.streamFromInitialNode(stateCreate(inputs), RunnableConfig.builder().build());
    }

    /**
     * Stream async generator.
     * @return the async generator
     */
    public AsyncGenerator<NodeOutput> stream() throws GraphRunnerException {
       return this.stream(Map.of(), RunnableConfig.builder().build());
    }

    /**
     * Invokes the graph execution with the provided inputs and returns the final state.
     * @param inputs the input map
     * @param config the invoke configuration
     * @return an Optional containing the final state if present, otherwise an empty
     * Optional
     */
    public Optional<OverAllState> invoke(Map<String, Object> inputs, RunnableConfig config)
          throws GraphRunnerException {
       return stream(inputs, config).stream().reduce((a, b) -> b).map(NodeOutput::state);
    }

    /**
     * Invoke optional.
     * @param overAllState the over all state
     * @param config the config
     * @return the optional
     */
    public Optional<OverAllState> invoke(OverAllState overAllState, RunnableConfig config) throws GraphRunnerException {
       return streamFromInitialNode(overAllState, config).stream().reduce((a, b) -> b).map(NodeOutput::state);
    }

    /**
     * Invokes the graph execution with the provided inputs and returns the final state.
     * @param inputs the input map
     * @return an Optional containing the final state if present, otherwise an empty
     * Optional
     */
    public Optional<OverAllState> invoke(Map<String, Object> inputs) throws GraphRunnerException {
       return this.invoke(stateCreate(inputs), RunnableConfig.builder().build());
    }

    private OverAllState stateCreate(Map<String, Object> inputs) {
       // Creates a new OverAllState instance based on the presence of an
       // OverAllStateFactory in the stateGraph.
       // If no factory is present, constructs a new state using key strategies from
       // the
       // graph and provided input data.
       // If a factory exists, uses it to create the state and applies the input data.
       return Objects.isNull(stateGraph.getOverAllStateFactory()) ? OverAllStateBuilder.builder()
          .withKeyStrategies(stateGraph.getKeyStrategyFactory().apply())
          .withData(inputs)
          .build() : stateGraph.getOverAllStateFactory().create().input(inputs);
    }

    /**
     * Experimental API
     * @param feedback the feedback
     * @param config the config
     * @return the optional
     */
    public Optional<OverAllState> resume(OverAllState.HumanFeedback feedback, RunnableConfig config)
          throws GraphRunnerException {
       StateSnapshot stateSnapshot = this.getState(config);
       OverAllState resumeState = stateCreate(stateSnapshot.state().data());
       resumeState.withResume();
       resumeState.withHumanFeedback(feedback);

       return this.invoke(resumeState, config);
    }

    /**
     * Creates an AsyncGenerator stream of NodeOutput based on the provided inputs, with
     * the stream mode set to SNAPSHOTS.
     * @param inputs the input map
     * @param config the invoke configuration
     * @return an AsyncGenerator stream of NodeOutput
     */
    public AsyncGenerator<NodeOutput> streamSnapshots(Map<String, Object> inputs, RunnableConfig config)
          throws GraphRunnerException {
       Objects.requireNonNull(config, "config cannot be null");

       final AsyncNodeGenerator<NodeOutput> generator = new AsyncNodeGenerator<>(stateCreate(inputs),
             config.withStreamMode(StreamMode.SNAPSHOTS));

       return new AsyncGenerator.WithEmbed<>(generator);
    }

    /**
     * Generates a drawable graph representation of the state graph.
     * @param type the type of graph representation to generate
     * @param title the title of the graph
     * @param printConditionalEdges whether to print conditional edges
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) {

       String content = type.generator.generate(processedData.nodes(), processedData.edges(), title,
             printConditionalEdges);

       return new GraphRepresentation(type, content);
    }

    /**
     * Get the last StateSnapshot of the given RunnableConfig.
     * @param config - the RunnableConfig
     * @return the last StateSnapshot of the given RunnableConfig if any
     */
    Optional<StateSnapshot> lastStateOf(RunnableConfig config) {
       return getStateHistory(config).stream().findFirst();
    }

    /**
     * Generates a drawable graph representation of the state graph.
     * @param type the type of graph representation to generate
     * @param title the title of the graph
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {

       String content = type.generator.generate(processedData.nodes(), processedData.edges(), title, true);

       return new GraphRepresentation(type, content);
    }

    /**
     * Generates a drawable graph representation of the state graph with default title.
     * @param type the type of graph representation to generate
     * @return a diagram code of the state graph
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type) {
       return getGraph(type, "Graph Diagram", true);
    }

    /**
     * Async Generator for streaming outputs.
     *
     * @param <Output> the type of the output
     */
    public class AsyncNodeGenerator<Output extends NodeOutput> implements AsyncGenerator<Output> {

       final Cursor cursor;

       // String currentNodeId;
       // String nextNodeId;
       int iteration = 0;

       final RunnableConfig config;

       volatile boolean returnFromEmbed = false;

       Map<String, Object> currentState;

       /**
        * The Over all state.
        */
       OverAllState overAllState;

       static class Cursor {

          private String currentNodeId;

          private String nextNodeId;

          private String resumeFrom;

          Cursor() {
             currentNodeId = START;
             nextNodeId = null;
             resumeFrom = null;
          }

          Cursor(Checkpoint cp) {
             currentNodeId = null;
             nextNodeId = cp.getNextNodeId();
             resumeFrom = cp.getNodeId();
          }

          void reset() {
             currentNodeId = null;
             nextNodeId = null;
             resumeFrom = null;
          }

          boolean isResumed() {
             return resumeFrom != null;
          }

          String nextNodeId() {
             return nextNodeId;
          }

          void setNextNodeId(String value) {
             nextNodeId = value;
          }

          String currentNodeId() {
             return currentNodeId;
          }

          void setCurrentNodeId(String value) {
             currentNodeId = value;
          }

          String resumeFrom() {
             return resumeFrom;
          }

          void setResumeFrom(String value) {
             resumeFrom = value;
          }

       }

       /**
        * Instantiates a new Async node generator.
        * @param overAllState the over all state
        * @param config the config
        */
       protected AsyncNodeGenerator(OverAllState overAllState, RunnableConfig config) throws GraphRunnerException {

          if (overAllState.isResume()) {

             log.trace("RESUME REQUEST");

             BaseCheckpointSaver saver = compileConfig.checkpointSaver()
                .orElseThrow(() -> (new IllegalStateException(
                      "inputs cannot be null (ie. resume request) if no checkpoint saver is configured")));
             Checkpoint startCheckpoint = saver.get(config)
                .orElseThrow(() -> (new IllegalStateException("Resume request without a saved checkpoint!")));

             this.currentState = startCheckpoint.getState();
             this.cursor = new Cursor(startCheckpoint);
             this.config = config.withCheckPointId(null);
             this.overAllState = overAllState.input(this.currentState);
             log.trace("RESUME FROM {}", startCheckpoint.getNodeId());

             // this.nextNodeId = startCheckpoint.getNextNodeId();
             // this.currentNodeId = null;
             // log.trace("RESUME FROM {}", startCheckpoint.getNodeId());
          }
          else {

             log.trace("START");
             Map<String, Object> inputs = overAllState.data();
             boolean verify = overAllState.keyVerify();
             if (!CollectionUtils.isEmpty(inputs) && !verify) {
                throw RunnableErrors.initializationError.exception(Arrays.toString(inputs.keySet().toArray()));
             }
             // patch for backward support of AppendableValue
             this.currentState = getInitialState(inputs, config);
             this.overAllState = overAllState.input(currentState);
             this.cursor = new Cursor();
             // this.nextNodeId = null;
             // this.currentNodeId = START;
             this.config = config;
          }
       }

       private Optional<BaseCheckpointSaver.Tag> releaseThread() throws Exception {
          if (compileConfig.releaseThread() && compileConfig.checkpointSaver().isPresent()) {
             return Optional.of(compileConfig.checkpointSaver().get().release(config));
          }
          return Optional.empty();
       }

       /**
        * Build node output output.
        * @param nodeId the node id
        * @return the output
        */
       @SuppressWarnings("unchecked")
       protected Output buildNodeOutput(String nodeId) {
          return (Output) NodeOutput.of(nodeId, cloneState(currentState));
       }

       /**
        * Clone state over all state.
        * @param data the data
        * @return the over all state
        */
       OverAllState cloneState(Map<String, Object> data) {
          return new OverAllState(data, keyStrategyMap, overAllState.isResume());
       }

       /**
        * Build state snapshot output.
        * @param checkpoint the checkpoint
        * @return the output
        */
       @SuppressWarnings("unchecked")
       protected Output buildStateSnapshot(Checkpoint checkpoint) {
          return (Output) StateSnapshot.of(keyStrategyMap, checkpoint, config,
                stateGraph.getStateSerializer().stateFactory());
       }

       /**
        * Gets embed generator from partial state.
        * @param partialState the partial state containing generator instances
        * @return an Optional containing Data with the generator if found, empty
        * otherwise
        */
       private Optional<Data<Output>> getEmbedGenerator(Map<String, Object> partialState) {
          // Extract all AsyncGenerator instances
          List<AsyncGenerator<Output>> asyncNodeGenerators = new ArrayList<>();
          var generatorEntries = partialState.entrySet().stream().filter(e -> {
             // Fixed when parallel nodes return asynchronous generating the same key
             Object value = e.getValue();
             if (value instanceof AsyncGenerator) {
                return true;
             }
             if (value instanceof Collection collection) {
                collection.forEach(o -> {
                   if (o instanceof AsyncGenerator<?>) {
                      asyncNodeGenerators.add((AsyncGenerator<Output>) o);
                   }
                });
             }
             return false;
          }).collect(Collectors.toList());

          if (generatorEntries.isEmpty() && asyncNodeGenerators.isEmpty()) {
             return Optional.empty();
          }

          // Log information about found generators
          if (generatorEntries.size() > 1) {
             log.debug("Multiple generators found: {} - keys: {}", generatorEntries.size(),
                   generatorEntries.stream().map(Map.Entry::getKey).collect(Collectors.joining(", ")));
          }

          // Create appropriate generator (single or merged)
          AsyncGenerator<Output> generator = AsyncGeneratorUtils.createAppropriateGenerator(generatorEntries,
                asyncNodeGenerators, keyStrategyMap);

          // Create data processing logic for the generator
          return Optional.of(Data.composeWith(generator.map(n -> {
             n.setSubGraph(true);
             return n;
          }), data -> processGeneratorOutput(data, partialState, generatorEntries)));
       }

       /**
        * Processes output data from generator.
        * @param data output data from generator
        * @param partialState partial state
        * @param generatorEntries generator entries list
        * @throws Exception if an error occurs during processing
        */
       @SuppressWarnings("unchecked")
       private void processGeneratorOutput(Object data, Map<String, Object> partialState,
             List<Map.Entry<String, Object>> generatorEntries) throws Exception {
          // Remove all generators
          Map<String, Object> partialStateWithoutGenerators = new HashMap<>();
          for (Map.Entry<String, Object> entry : partialState.entrySet()) {
             if (entry.getValue() instanceof AsyncGenerator) {
                continue; // Skip top-level AsyncGenerator values
             }

             if (entry.getValue() instanceof Collection<?>) {
                Collection<?> collection = (Collection<?>) entry.getValue();
                ArrayList<Object> filteredCollection = new ArrayList<>();

                for (Object item : collection) {
                   if (!(item instanceof AsyncGenerator)) {
                      filteredCollection.add(item);
                   }
                }

                if (!filteredCollection.isEmpty()) {
                   partialStateWithoutGenerators.put(entry.getKey(), filteredCollection);
                }
             }
             else {
                // Keep the entry if it's not an AsyncGenerator and not a collection
                // containing it
                partialStateWithoutGenerators.put(entry.getKey(), entry.getValue());
             }
          }

          // Update state with partial state without generators
          var intermediateState = OverAllState.updateState(currentState, partialStateWithoutGenerators,
                keyStrategyMap);
          currentState = intermediateState;
          overAllState.updateState(partialStateWithoutGenerators);

          // If data is not null and is a Map, update state with it
          if (data != null) {
             if (data instanceof Map<?, ?>) {
                currentState = OverAllState.updateState(intermediateState, (Map<String, Object>) data,
                      keyStrategyMap);
                overAllState.updateState((Map<String, Object>) data);

                if (log.isDebugEnabled() && generatorEntries.size() > 1) {
                   log.debug("Updated state with data keys: {}",
                         ((Map<String, Object>) data).keySet().stream().collect(Collectors.joining(", ")));
                }
             }
             else {
                throw new IllegalArgumentException("Embedded generator must return a Map");
             }
          }

          // Get next node command
          var nextNodeCommand = nextNodeId(cursor.currentNodeId(), currentState, config);
          cursor.setNextNodeId(nextNodeCommand.gotoNode());
          currentState = nextNodeCommand.update();
          returnFromEmbed = true;
       }

       private CompletableFuture<Data<Output>> evaluateAction(AsyncNodeActionWithConfig action,
             OverAllState withState) {
          try {
             doListeners(NODEBEFORE, null);
             return action.apply(withState, config).thenApply(updateState -> {
                try {
                   if (action instanceof CommandNode.AsyncCommandNodeActionWithConfig) {
                      AsyncCommandAction commandAction = (AsyncCommandAction) updateState.get("command");
                      Command command = commandAction.apply(withState, config).join();

                      this.currentState = OverAllState.updateState(currentState, command.update(),
                            keyStrategyMap);
                      this.overAllState.updateState(command.update());
                      cursor.setNextNodeId(command.gotoNode());
                      return Data.of(getNodeOutput());
                   }

                   Optional<Data<Output>> embed = getEmbedGenerator(updateState);
                   if (embed.isPresent()) {
                      return embed.get();
                   }

                   this.currentState = OverAllState.updateState(currentState, updateState, keyStrategyMap);
                   this.overAllState.updateState(updateState);
                   if (compileConfig.interruptBeforeEdge()
                         && compileConfig.interruptsAfter().contains(cursor.currentNodeId())) {
                      // nextNodeId = INTERRUPTAFTER;
                      cursor.setNextNodeId(INTERRUPTAFTER);
                   }
                   else {
                      var nextNodeCommand = nextNodeId(cursor.currentNodeId(), currentState, config);
                      // nextNodeId = nextNodeCommand.gotoNode();
                      cursor.setNextNodeId(nextNodeCommand.gotoNode());
                      currentState = nextNodeCommand.update();

                   }

                   return Data.of(getNodeOutput());
                }
                catch (Exception e) {
                   throw new CompletionException(e);
                }
             }).whenComplete((outputData, throwable) -> doListeners(NODEAFTER, null));
          }
          catch (Exception e) {
             return failedFuture(e);
          }

       }

       /**
        * Determines the next node ID based on the current node ID and state.
        * @param nodeId the current node ID
        * @param state the current state
        * @return the next node command
        * @throws Exception if there is an error determining the next node ID
        */
       private Command nextNodeId(String nodeId, Map<String, Object> state, RunnableConfig config) throws Exception {
          return nextNodeId(edges.get(nodeId), state, nodeId, config);

       }

       private Command nextNodeId(EdgeValue route, Map<String, Object> state, String nodeId, RunnableConfig config)
             throws Exception {

          if (route == null) {
             throw RunnableErrors.missingEdge.exception(nodeId);
          }
          if (route.id() != null) {
             return new Command(route.id(), state);
          }
          if (route.value() != null) {

             var command = route.value().action().apply(this.overAllState, config).get();

             var newRoute = command.gotoNode();

             String result = route.value().mappings().get(newRoute);
             if (result == null) {
                throw RunnableErrors.missingNodeInEdgeMapping.exception(nodeId, newRoute);
             }

             var currentState = OverAllState.updateState(state, command.update(), keyStrategyMap);
             this.overAllState.updateState(command.update());
             return new Command(result, currentState);
          }
          throw RunnableErrors.executionError.exception(format("invalid edge value for nodeId: [%s] !", nodeId));
       }

       private CompletableFuture<Output> getNodeOutput() throws Exception {
          Optional<Checkpoint> cp = addCheckpoint(config, cursor.currentNodeId(), currentState, cursor.nextNodeId());
          return completedFuture((cp.isPresent() && config.streamMode() == StreamMode.SNAPSHOTS)
                ? buildStateSnapshot(cp.get()) : buildNodeOutput(cursor.currentNodeId()));
       }

       @Override
       public Data<Output> next() {
          try {
             // GUARD: CHECK MAX ITERATION REACHED
             if (++iteration > maxIterations) {
                // log.warn( "Maximum number of iterations ({}) reached!",
                // maxIterations);
                return Data.error(new IllegalStateException(
                      format("Maximum number of iterations (%d) reached!", maxIterations)));
             }

             // GUARD: CHECK IF IT IS END
             if (cursor.nextNodeId() == null && cursor.currentNodeId() == null) {
                return releaseThread().map(Data::<Output>done).orElseGet(() -> Data.done(currentState));
             }

             // IS IT A RESUME FROM EMBED ?
             if (returnFromEmbed) {
                final CompletableFuture<Output> future = getNodeOutput();
                returnFromEmbed = false;
                return Data.of(future);
             }

             if (cursor.currentNodeId() != null && config.isInterrupted(cursor.currentNodeId())) {
                config.withNodeResumed(cursor.currentNodeId());
                return Data.done(currentState);
             }

             if (START.equals(cursor.currentNodeId())) {
                doListeners(START, null);
                var nextNodeCommand = getEntryPoint(currentState, config);
                cursor.setNextNodeId(nextNodeCommand.gotoNode());
                currentState = nextNodeCommand.update();

                var cp = addCheckpoint(config, START, currentState, cursor.nextNodeId());

                var output = (cp.isPresent() && config.streamMode() == StreamMode.SNAPSHOTS)
                      ? buildStateSnapshot(cp.get()) : buildNodeOutput(cursor.currentNodeId());

                cursor.setCurrentNodeId(cursor.nextNodeId());
                // currentNodeId = nextNodeId;

                return Data.of(output);
             }

             if (END.equals(cursor.nextNodeId())) {
                cursor.reset();
                doListeners(END, null);
                // nextNodeId = null;
                // currentNodeId = null;
                return Data.of(buildNodeOutput(END));
             }

             if (cursor.isResumed()) {

                if (compileConfig.interruptBeforeEdge() && Objects.equals(cursor.nextNodeId(), INTERRUPTAFTER)) {
                   var nextNodeCommand = nextNodeId(cursor.resumeFrom(), currentState, config);
                   // nextNodeId = nextNodeCommand.gotoNode();
                   cursor.setNextNodeId(nextNodeCommand.gotoNode());

                   currentState = nextNodeCommand.update();
                   cursor.setCurrentNodeId(null);

                }

                cursor.setResumeFrom(null);

             }

             // check on previous node
             if (shouldInterruptAfter(cursor.currentNodeId(), cursor.nextNodeId())) {
                return Data.done(cursor.currentNodeId());
             }

             if (shouldInterruptBefore(cursor.nextNodeId(), cursor.currentNodeId())) {
                return Data.done(cursor.nextNodeId());
             }

             cursor.setCurrentNodeId(cursor.nextNodeId());
             var action = nodes.get(cursor.currentNodeId());

             if (action == null)
                throw RunnableErrors.missingNode.exception(cursor.currentNodeId());

             return evaluateAction(action, this.overAllState).get();
          }
          catch (Exception e) {
             doListeners(ERROR, e);
             log.error(e.getMessage(), e);
             return Data.error(e);
          }
       }

       private void doListeners(String scene, Exception e) {
          Deque<GraphLifecycleListener> listeners = new LinkedBlockingDeque<>(compileConfig.lifecycleListeners());
          LifeListenerUtil.processListenersLIFO(this.cursor.currentNodeId(), listeners, this.currentState,
                this.config, scene, e);
       }

    }

}

/**
 * The type Processed nodes edges and config.
 */
record ProcessedNodesEdgesAndConfig(StateGraph.Nodes nodes, StateGraph.Edges edges, Set<String> interruptsBefore,
       Set<String> interruptsAfter) {

    /**
     * Instantiates a new Processed nodes edges and config.
     * @param stateGraph the state graph
     * @param config the config
     */
    ProcessedNodesEdgesAndConfig(StateGraph stateGraph, CompileConfig config) {
       this(stateGraph.nodes, stateGraph.edges, config.interruptsBefore(), config.interruptsAfter());
    }

    /**
     * Process processed nodes edges and config.
     * @param stateGraph the state graph
     * @param config the config
     * @return the processed nodes edges and config
     * @throws GraphStateException the graph state exception
     */
    static ProcessedNodesEdgesAndConfig process(StateGraph stateGraph, CompileConfig config)
          throws GraphStateException {

       var subgraphNodes = stateGraph.nodes.onlySubStateGraphNodes();

       if (subgraphNodes.isEmpty()) {
          return new ProcessedNodesEdgesAndConfig(stateGraph, config);
       }

       var interruptsBefore = config.interruptsBefore();
       var interruptsAfter = config.interruptsAfter();
       var nodes = new StateGraph.Nodes(stateGraph.nodes.exceptSubStateGraphNodes());
       var edges = new StateGraph.Edges(stateGraph.edges.elements);

       for (var subgraphNode : subgraphNodes) {

          var sgWorkflow = subgraphNode.subGraph();

          ProcessedNodesEdgesAndConfig processedSubGraph = process(sgWorkflow, config);
          StateGraph.Nodes processedSubGraphNodes = processedSubGraph.nodes;
          StateGraph.Edges processedSubGraphEdges = processedSubGraph.edges;

          //
          // Process START Node
          //
          var sgEdgeStart = processedSubGraphEdges.edgeBySourceId(START).orElseThrow();

          if (sgEdgeStart.isParallel()) {
             throw new GraphStateException("subgraph not support start with parallel branches yet!");
          }

          var sgEdgeStartTarget = sgEdgeStart.target();

          if (sgEdgeStartTarget.id() == null) {
             throw new GraphStateException(format("the target for node '%s' is null!", subgraphNode.id()));
          }

          var sgEdgeStartRealTargetId = subgraphNode.formatId(sgEdgeStartTarget.id());

          // Process Interruption (Before) Subgraph(s)
          interruptsBefore = interruptsBefore.stream()
             .map(interrupt -> Objects.equals(subgraphNode.id(), interrupt) ? sgEdgeStartRealTargetId : interrupt)
             .collect(Collectors.toUnmodifiableSet());

          var edgesWithSubgraphTargetId = edges.edgesByTargetId(subgraphNode.id());

          if (edgesWithSubgraphTargetId.isEmpty()) {
             throw new GraphStateException(
                   format("the node '%s' is not present as target in graph!", subgraphNode.id()));
          }

          for (var edgeWithSubgraphTargetId : edgesWithSubgraphTargetId) {

             var newEdge = edgeWithSubgraphTargetId.withSourceAndTargetIdsUpdated(subgraphNode, Function.identity(),
                   id -> new EdgeValue((Objects.equals(id, subgraphNode.id())
                         ? subgraphNode.formatId(sgEdgeStartTarget.id()) : id)));
             edges.elements.remove(edgeWithSubgraphTargetId);
             edges.elements.add(newEdge);
          }
          //
          // Process END Nodes
          //
          var sgEdgesEnd = processedSubGraphEdges.edgesByTargetId(END);

          var edgeWithSubgraphSourceId = edges.edgeBySourceId(subgraphNode.id()).orElseThrow();

          if (edgeWithSubgraphSourceId.isParallel()) {
             throw new GraphStateException("subgraph not support routes to parallel branches yet!");
          }

          // Process Interruption (After) Subgraph(s)
          if (interruptsAfter.contains(subgraphNode.id())) {

             var exceptionMessage = (edgeWithSubgraphSourceId.target()
                .id() == null) ? "'interruption after' on subgraph is not supported yet!" : format(
                      "'interruption after' on subgraph is not supported yet! consider to use 'interruption before' node: '%s'",
                      edgeWithSubgraphSourceId.target().id());
             throw new GraphStateException(exceptionMessage);
          }

          sgEdgesEnd.stream()
             .map(e -> e.withSourceAndTargetIdsUpdated(subgraphNode, subgraphNode::formatId,
                   id -> (Objects.equals(id, END) ? edgeWithSubgraphSourceId.target()
                         : new EdgeValue(subgraphNode.formatId(id)))))
             .forEach(edges.elements::add);
          edges.elements.remove(edgeWithSubgraphSourceId);

          //
          // Process edges
          //
          processedSubGraphEdges.elements.stream()
             .filter(e -> !Objects.equals(e.sourceId(), START))
             .filter(e -> !e.anyMatchByTargetId(END))
             .map(e -> e.withSourceAndTargetIdsUpdated(subgraphNode, subgraphNode::formatId,
                   id -> new EdgeValue(subgraphNode.formatId(id))))
             .forEach(edges.elements::add);

          //
          // Process nodes
          //
          processedSubGraphNodes.elements.stream().map(n -> {
             if (n instanceof CommandNode commandNode) {
                Map<String, String> mappings = commandNode.getMappings();
                HashMap<String, String> newMappings = new HashMap<>();
                mappings.forEach((key, value) -> {
                   newMappings.put(key, subgraphNode.formatId(value));
                });
                return new CommandNode(subgraphNode.formatId(n.id()),
                      AsyncCommandAction.nodeasync((state, config1) -> {
                         Command command = commandNode.getAction().apply(state, config1).join();
                         String NewGoToNode = subgraphNode.formatId(command.gotoNode());
                         return new Command(NewGoToNode, command.update());
                      }), newMappings);
             }
             return n.withIdUpdated(subgraphNode::formatId);
          }).forEach(nodes.elements::add);
       }

       return new ProcessedNodesEdgesAndConfig(nodes, edges, interruptsBefore, interruptsAfter);
    }
}
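
The `process` method above inlines each subgraph into its parent graph. The edge that entered the subgraph node is redirected to the subgraph's first real node, the subgraph's edges into END are redirected to wherever the subgraph node used to lead, and every inner ID is rewritten through `formatId`. The following is a minimal sketch of that rewiring for fixed (non-conditional, non-parallel) edges only. It models edges as a plain `Map<String, String>` rather than `StateGraph.Edges`; the class name `SubgraphFlattenSketch`, the `START`/`END` string values, and the `sub-x` ID format are assumptions made for illustration and are not taken from the library.

```java
import java.util.LinkedHashMap;
import java.util.Map;

class SubgraphFlattenSketch {

    // Placeholder marker values; the real constants live in StateGraph.
    static final String START = "__START__";
    static final String END = "__END__";

    // Hypothetical ID format; the real one is defined by the subgraph node's formatId.
    static String formatId(String subgraphId, String nodeId) {
        return subgraphId + "-" + nodeId;
    }

    /** Inlines a subgraph (sourceId -> targetId edges) into the parent's edges. */
    static Map<String, String> flatten(Map<String, String> parent, String subgraphId,
            Map<String, String> sub) {
        String entry = formatId(subgraphId, sub.get(START)); // first real node of the subgraph
        String exit = parent.get(subgraphId); // where the subgraph node used to lead
        Map<String, String> result = new LinkedHashMap<>();

        // Parent edges: drop the edge leaving the subgraph node, redirect edges entering it.
        parent.forEach((source, target) -> {
            if (!source.equals(subgraphId)) {
                result.put(source, target.equals(subgraphId) ? entry : target);
            }
        });

        // Subgraph edges: skip START, prefix inner IDs, map END to the parent's exit target.
        sub.forEach((source, target) -> {
            if (!source.equals(START)) {
                result.put(formatId(subgraphId, source),
                        target.equals(END) ? exit : formatId(subgraphId, target));
            }
        });
        return result;
    }

    public static void main(String[] args) {
        Map<String, String> parent = new LinkedHashMap<>();
        parent.put(START, "a");
        parent.put("a", "sub");
        parent.put("sub", "b");
        parent.put("b", END);

        Map<String, String> sub = new LinkedHashMap<>();
        sub.put(START, "x");
        sub.put("x", "y");
        sub.put("y", END);

        // prints {__START__=a, a=sub-x, b=__END__, sub-x=sub-y, sub-y=b}
        System.out.println(flatten(parent, "sub", sub));
    }
}
```

After flattening, no edge mentions `sub` any more, which is why the compiled graph can execute subgraph nodes as if they had been declared in the parent.
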
Node

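The core class for a node in the graph structure. It defines a single node through its identifier and an optional action factory.

| Field | Type | Description |
| --- | --- | --- |
| id | String | Unique identifier of the node |
| actionFactory | ActionFactory | Factory that creates the action the node executes |

Public methods

| Method | Description |
| --- | --- |
| isParallel | Checks whether the node is a parallel node; the current implementation always returns false |
| withIdUpdated | Returns a new node instance whose ID has been transformed by the given function |
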
package com.alibaba.cloud.ai.graph.internal.node;

import com.alibaba.cloud.ai.graph.CompileConfig;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;

import java.util.Objects;
import java.util.function.Function;

import static java.lang.String.format;

/**
 * Represents a node in a graph, characterized by a unique identifier and a factory for
 * creating actions to be executed by the node. This is a generic record where the state
 * type is specified by the type parameter {@code State}.
 *
 * {@link OverAllState}.
 *
 */
public class Node {

    public interface ActionFactory {

       AsyncNodeActionWithConfig apply(CompileConfig config) throws GraphStateException;

    }

    private final String id;

    private final ActionFactory actionFactory;

    public Node(String id, ActionFactory actionFactory) {
       this.id = id;
       this.actionFactory = actionFactory;
    }

    /**
     * Constructor that accepts only the `id` and sets `actionFactory` to null.
     * @param id the unique identifier for the node
     */
    public Node(String id) {
       this(id, null);
    }

    /**
     * id
     * @return the unique identifier for the node.
     */
    public String id() {
       return id;
    }

    /**
     * actionFactory
     * @return a factory function that takes a {@link CompileConfig} and returns an
     * {@link AsyncNodeActionWithConfig} instance for the specified {@code State}.
     */
    public ActionFactory actionFactory() {
       return actionFactory;
    }

    public boolean isParallel() {
       // return id.startsWith(PARALLELPREFIX);
       return false;
    }

    public Node withIdUpdated(Function<String, String> newId) {
       return new Node(newId.apply(id), actionFactory);
    }

    /**
     * Checks if this node is equal to another object.
     * @param o the object to compare with
     * @return true if this node is equal to the specified object, false otherwise
     */
    @Override
    public boolean equals(Object o) {
       if (this == o)
          return true;
       if (o == null)
          return false;
       if (o instanceof Node node) {
          return Objects.equals(id, node.id);
       }
       return false;

    }

    /**
     * Returns the hash code value for this node.
     * @return the hash code value for this node
     */
    @Override
    public int hashCode() {
       return Objects.hash(id);
    }

    @Override
    public String toString() {
       return format("Node(%s,%s)", id, actionFactory != null ? "action" : "null");
    }

}
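
Because `equals` and `hashCode` in the listing above use only `id`, two `Node` instances with the same ID are interchangeable in hash-based collections regardless of their action factories, and `withIdUpdated` yields a renamed copy that keeps the original factory. The sketch below demonstrates both points. `MiniNode` is a hypothetical stand-in that replaces `ActionFactory` with a plain `Supplier<String>` so that it compiles without the library, and the `sub-` prefix is only an example.

```java
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;

class NodeIdentitySketch {

    /** Hypothetical stand-in for Node; Supplier<String> replaces ActionFactory. */
    static final class MiniNode {

        private final String id;

        private final Supplier<String> actionFactory;

        MiniNode(String id, Supplier<String> actionFactory) {
            this.id = id;
            this.actionFactory = actionFactory;
        }

        String id() {
            return id;
        }

        Supplier<String> actionFactory() {
            return actionFactory;
        }

        // Same idea as Node.withIdUpdated: new ID, same factory.
        MiniNode withIdUpdated(Function<String, String> newId) {
            return new MiniNode(newId.apply(id), actionFactory);
        }

        // Identity is the ID alone, as in Node.equals and Node.hashCode.
        @Override
        public boolean equals(Object o) {
            return o instanceof MiniNode && Objects.equals(id, ((MiniNode) o).id);
        }

        @Override
        public int hashCode() {
            return Objects.hash(id);
        }
    }

    public static void main(String[] args) {
        MiniNode withAction = new MiniNode("agent", () -> "call the model");
        MiniNode lookupKey = new MiniNode("agent", null);

        Set<MiniNode> nodes = new HashSet<>();
        nodes.add(withAction);
        System.out.println(nodes.contains(lookupKey)); // true: only the ID is compared
        System.out.println(nodes.add(lookupKey)); // false: treated as a duplicate

        MiniNode renamed = withAction.withIdUpdated(id -> "sub-" + id);
        System.out.println(renamed.id()); // sub-agent
        System.out.println(renamed.actionFactory().get()); // call the model
    }
}
```

This ID-only identity is what makes the `Node(String id)` constructor useful as a lookup key: a node built without an action factory still matches the registered node with the same ID.
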
Edge

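Defines the connection between nodes. It is a record class.

| Field | Type | Description |
| --- | --- | --- |
| sourceId | String | ID of the source node, where the edge starts |
| targets | `List<EdgeValue>` | Target nodes or conditions of the edge |

| Category | Method | Description |
| --- | --- | --- |
| Constructors | Edge(String id) | Creates an edge that has only a source node |
| | Edge(String sourceId, EdgeValue target) | Creates an edge from the source node to a single target |
| | `Edge(String sourceId, List<EdgeValue> targets)` | Creates an edge from the source node to multiple targets |
| Edge checks | isParallel | Returns whether the edge is parallel (more than one target) |
| | target | Returns the single target; throws an exception if the edge is parallel |
| | anyMatchByTargetId | Checks whether any target refers to the given target node ID |
| | withSourceAndTargetIdsUpdated | Updates the source node ID and the target values, returning a new Edge instance |
| | validate | Validates the edge: the source and every target must reference an existing node (START and END excepted), and a parallel edge must not repeat a target |
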
package com.alibaba.cloud.ai.graph.internal.edge;

import com.alibaba.cloud.ai.graph.exception.Errors;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.internal.node.Node;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import static com.alibaba.cloud.ai.graph.StateGraph.START;
import static java.lang.String.format;

/**
 * Represents an edge in a graph with a source ID and a target value.
 *
 * @param sourceId The ID of the source node.
 * @param targets The targets value associated with the edge.
 */
public record Edge(String sourceId, List<EdgeValue> targets) {

    public Edge(String sourceId, EdgeValue target) {
       this(sourceId, List.of(target));
    }

    public Edge(String id) {
       this(id, List.of());
    }

    public boolean isParallel() {
       return targets.size() > 1;
    }

    public EdgeValue target() {
       if (isParallel()) {
          throw new IllegalStateException(format("Edge '%s' is parallel", sourceId));
       }
       return targets.get(0);
    }

    public boolean anyMatchByTargetId(String targetId) {
       return targets().stream()
          .anyMatch(v -> (v.id() != null) ? Objects.equals(v.id(), targetId)
                : v.value().mappings().containsValue(targetId)

          );
    }

    public Edge withSourceAndTargetIdsUpdated(Node node, Function<String, String> newSourceId,
          Function<String, EdgeValue> newTarget) {

       var newTargets = targets().stream().map(t -> t.withTargetIdsUpdated(newTarget)).toList();
       return new Edge(newSourceId.apply(sourceId), newTargets);

    }

    public void validate(StateGraph.Nodes nodes) throws GraphStateException {
       if (!Objects.equals(sourceId(), START) && !nodes.anyMatchById(sourceId())) {
          throw Errors.missingNodeReferencedByEdge.exception(sourceId());
       }

       if (isParallel()) { // check for duplicates targets
          Set<String> duplicates = targets.stream()
             .collect(Collectors.groupingBy(EdgeValue::id, Collectors.counting())) // Group
                                                                   // by
                                                                   // element
                                                                   // and
                                                                   // count
                                                                   // occurrences
             .entrySet()
             .stream()
             .filter(entry -> entry.getValue() > 1) // Filter elements with more than
                                           // one occurrence
             .map(Map.Entry::getKey)
             .collect(Collectors.toSet());
          if (!duplicates.isEmpty()) {
             throw Errors.duplicateEdgeTargetError.exception(sourceId(), duplicates);
          }
       }

       for (EdgeValue target : targets) {
          validate(target, nodes);
       }

    }

    private void validate(EdgeValue target, StateGraph.Nodes nodes) throws GraphStateException {
       if (target.id() != null) {
          if (!Objects.equals(target.id(), StateGraph.END) && !nodes.anyMatchById(target.id())) {
             throw Errors.missingNodeReferencedByEdge.exception(target.id());
          }
       }
       else if (target.value() != null) {
          for (String nodeId : target.value().mappings().values()) {
             if (!Objects.equals(nodeId, StateGraph.END) && !nodes.anyMatchById(nodeId)) {
                throw Errors.missingNodeInEdgeMapping.exception(sourceId(), nodeId);
             }
          }
       }
       else {
          throw Errors.invalidEdgeTarget.exception(sourceId());
       }

    }

    /**
     * Checks if this edge is equal to another object.
     * @param o the object to compare with
     * @return true if this edge is equal to the specified object, false otherwise
     */
    @Override
    public boolean equals(Object o) {
       if (this == o)
          return true;
       if (o == null || getClass() != o.getClass())
          return false;
       Edge node = (Edge) o;
       return Objects.equals(sourceId, node.sourceId);
    }

    /**
     * Returns the hash code value for this edge.
     * @return the hash code value for this edge
     */
    @Override
    public int hashCode() {
       return Objects.hash(sourceId);
    }

}
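
In `validate`, a parallel edge with repeated targets is rejected by grouping the target IDs, counting each group, and keeping the IDs that occur more than once. The sketch below isolates that step on plain strings; `EdgeTargetsSketch` and `duplicateTargets` are hypothetical names, not library API. One detail worth noting about the original: `Collectors.groupingBy` throws a `NullPointerException` when its classifier returns `null`, so that grouping presumably expects every target of a parallel edge to be a fixed target with a non-null `id`.

```java
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

class EdgeTargetsSketch {

    // A parallel edge is simply one with more than one target.
    static boolean isParallel(List<String> targetIds) {
        return targetIds.size() > 1;
    }

    // Group identical IDs, count them, and keep the ones seen more than once.
    static Set<String> duplicateTargets(List<String> targetIds) {
        return targetIds.stream()
            .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()))
            .entrySet()
            .stream()
            .filter(entry -> entry.getValue() > 1)
            .map(Map.Entry::getKey)
            .collect(Collectors.toSet());
    }

    public static void main(String[] args) {
        List<String> targets = List.of("search", "summarize", "search");
        System.out.println(isParallel(targets)); // true
        System.out.println(duplicateTargets(targets)); // [search]
        System.out.println(duplicateTargets(List.of("search", "summarize"))); // []
    }
}
```
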
EdgeValue

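The core class for an edge's value. It defines either the target node or the condition of an edge, which is how conditional edges are supported.

| Field | Type | Description |
| --- | --- | --- |
| id | String | Unique identifier of the target node; used for fixed edges |
| value | EdgeCondition | Condition associated with the edge value; used for conditional edges |

| Category | Method | Description |
| --- | --- | --- |
| Constructors | EdgeValue(String id) | Creates an EdgeValue with only a target node ID; the condition is null |
| | EdgeValue(EdgeCondition value) | Creates an EdgeValue with only a condition; the target node ID is null |
| | EdgeValue(String id, EdgeCondition value) | Creates an EdgeValue with both a target node ID and a condition |
| Update | `withTargetIdsUpdated(Function<String, EdgeValue> target)` | For a fixed target, returns the function's result; for a conditional target, returns a new EdgeValue whose mapping targets have been rewritten by the function |
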
package com.alibaba.cloud.ai.graph.internal.edge;

import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * @param id The unique identifier for the edge value.
 * @param value The condition associated with the edge value.
 */
public record EdgeValue(String id, EdgeCondition value) {

    public EdgeValue(String id) {
       this(id, null);
    }

    public EdgeValue(EdgeCondition value) {
       this(null, value);
    }

    EdgeValue withTargetIdsUpdated(Function<String, EdgeValue> target) {
       if (id != null) {
          return target.apply(id);
       }

       var newMappings = value.mappings().entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> {
          var v = target.apply(e.getValue());
          return (v.id() != null) ? v.id() : e.getValue();
       }));

       return new EdgeValue(null, new EdgeCondition(value.action(), newMappings));

    }

}
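
`withTargetIdsUpdated` treats the two kinds of edge value differently. A fixed target is replaced by whatever the function returns, while a conditional target keeps its action and only has the values of its `mappings` rewritten; the keys, which are the labels returned by the condition, stay the same. Below is a minimal sketch of the mapping branch. It assumes a `Function<String, String>` in place of `Function<String, EdgeValue>`, and the class name, the `__END__` marker, and the `sub-` prefix are illustrative only.

```java
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;

class MappingRewriteSketch {

    // Rewrites only the map values (target node IDs); keys (route labels) are untouched.
    // If the function yields null, the original target is kept, like the fallback in the original.
    static Map<String, String> rewriteTargets(Map<String, String> mappings,
            Function<String, String> newTarget) {
        return mappings.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> {
            String updated = newTarget.apply(e.getValue());
            return updated != null ? updated : e.getValue();
        }));
    }

    public static void main(String[] args) {
        Map<String, String> mappings = Map.of("continue", "tools", "stop", "__END__");

        // Prefix every target except the placeholder end marker.
        Map<String, String> rewritten = rewriteTargets(mappings,
                id -> id.equals("__END__") ? id : "sub-" + id);

        // TreeMap only gives the printout a stable key order.
        System.out.println(new TreeMap<>(rewritten)); // {continue=sub-tools, stop=__END__}
    }
}
```

Keeping the keys fixed matters because the condition's action is reused unchanged: it still returns the same labels, and only the nodes those labels point to have moved.
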
EdgeCondition

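Defines the condition of a conditional edge and its mappings.

| Field | Type | Description |
| --- | --- | --- |
| action | AsyncCommandAction | Asynchronous command action that runs the condition logic |
| mappings | `Map<String, String>` | Maps each result returned by the action to a target node ID |
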
package com.alibaba.cloud.ai.graph.internal.edge;

import com.alibaba.cloud.ai.graph.action.AsyncCommandAction;

import java.util.Map;

import static java.lang.String.format;

/**
 * Represents a condition associated with an edge in a graph.
 *
 * @param action The action to be performed asynchronously when the edge condition is met.
 * @param mappings A map of string key-value pairs representing additional mappings for
 * the edge condition.
 */
public record EdgeCondition(AsyncCommandAction action, Map<String, String> mappings) {

    @Override
    public String toString() {
       return format("EdgeCondition[ %s, mapping=%s", action != null ? "action" : "null", mappings);
    }

}
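
At run time, the `nextNodeId` method shown earlier resolves a route in three steps: a fixed target is returned directly; otherwise the condition's action is invoked and the label it returns is looked up in `mappings`; a label with no mapping is an error. The sketch below mirrors that control flow with simplified stand-ins. `Route` and `Condition` are hypothetical records, and a synchronous `Function<Map<String, Object>, String>` replaces `AsyncCommandAction`, so the state update carried by a `Command` is left out.

```java
import java.util.Map;
import java.util.function.Function;

class RouteResolutionSketch {

    /** Stand-in for EdgeCondition: a decision function plus label -> node ID mappings. */
    record Condition(Function<Map<String, Object>, String> action, Map<String, String> mappings) {
    }

    /** Stand-in for EdgeValue: either a fixed target ID or a condition. */
    record Route(String id, Condition value) {
    }

    static String nextNodeId(Route route, Map<String, Object> state, String nodeId) {
        if (route == null) {
            throw new IllegalStateException("missing edge for node: " + nodeId);
        }
        if (route.id() != null) {
            return route.id(); // fixed edge: the target is known up front
        }
        if (route.value() != null) {
            // Conditional edge: ask the action for a label, then translate it to a node ID.
            String label = route.value().action().apply(state);
            String target = route.value().mappings().get(label);
            if (target == null) {
                throw new IllegalStateException(
                        "label '" + label + "' has no mapping for node: " + nodeId);
            }
            return target;
        }
        throw new IllegalStateException("invalid edge value for node: " + nodeId);
    }

    public static void main(String[] args) {
        Route fixed = new Route("summarize", null);
        Route conditional = new Route(null,
                new Condition(state -> Boolean.TRUE.equals(state.get("needsTool")) ? "continue" : "stop",
                        Map.of("continue", "tools", "stop", "summarize")));

        System.out.println(nextNodeId(fixed, Map.of(), "agent")); // summarize
        System.out.println(nextNodeId(conditional, Map.of("needsTool", true), "agent")); // tools
        System.out.println(nextNodeId(conditional, Map.of(), "agent")); // summarize
    }
}
```
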

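Learning and discussion group

Hi, I'm 影子. I previously worked at 🐻, a new-energy company, and 老铁, and I also serve as a Committer in the Spring AI Alibaba open-source community. I have recently started a discussion group: alone you go fast, together you go far. Follow the official account (公众号) to get my personal WeChat, then add me with the note “交流” to join the group. I also maintain a long-running set of Feishu cloud-document notes covering systematic interview material for backend and big data; message me privately to get them for free.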
