diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphMetricSummary.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphMetricSummary.java new file mode 100644 index 00000000..05e45475 --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphMetricSummary.java @@ -0,0 +1,84 @@ +package com.launchdarkly.sdk.server.ai; + +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; + +import java.util.List; + +/** + * A snapshot of the metrics tracked so far by an {@link AIGraphTracker}. + *

+ * All fields are nullable: a {@code null} value means the corresponding metric has not been + * recorded yet on the tracker. {@link #getResumptionToken()} is always present. + *

+ * Instances are immutable. + */ +public final class AIGraphMetricSummary { + private final Boolean success; + private final Double durationMs; + private final TokenUsage tokens; + private final List path; + private final String resumptionToken; + + AIGraphMetricSummary( + Boolean success, + Double durationMs, + TokenUsage tokens, + List path, + String resumptionToken) { + this.success = success; + this.durationMs = durationMs; + this.tokens = tokens; + this.path = path; + this.resumptionToken = resumptionToken; + } + + /** + * Returns the invocation outcome: {@code true} if {@code trackInvocationSuccess} was called, + * {@code false} if {@code trackInvocationFailure} was called, or {@code null} if neither has + * been called yet. + * + * @return the success flag, or {@code null} if not yet recorded + */ + public Boolean getSuccess() { + return success; + } + + /** + * Returns the tracked graph-level duration in milliseconds, or {@code null} if not recorded. + * + * @return the duration in ms, or {@code null} + */ + public Double getDurationMs() { + return durationMs; + } + + /** + * Returns the tracked token usage, or {@code null} if not recorded. + * + * @return the token usage, or {@code null} + */ + public TokenUsage getTokens() { + return tokens; + } + + /** + * Returns the tracked node path (ordered list of node keys visited), or {@code null} if not + * recorded. + * + * @return an unmodifiable list of node keys, or {@code null} + */ + public List getPath() { + return path; + } + + /** + * Returns the resumption token for this graph run, which can be passed to + * {@link LDAIClient#createGraphTracker(String, com.launchdarkly.sdk.LDContext)} to reconstruct + * the tracker on a subsequent request. + * + * @return the resumption token; never {@code null} + */ + public String getResumptionToken() { + return resumptionToken; + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java new file mode 100644 index 00000000..ec234306 --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java @@ -0,0 +1,314 @@ +package com.launchdarkly.sdk.server.ai; + +import com.launchdarkly.logging.LDLogAdapter; +import com.launchdarkly.logging.LDLogger; +import com.launchdarkly.logging.LDSLF4J; +import com.launchdarkly.logging.Logs; +import com.launchdarkly.sdk.ArrayBuilder; +import com.launchdarkly.sdk.LDContext; +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.ObjectBuilder; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; +import com.launchdarkly.sdk.server.ai.internal.ResumptionTokens; +import com.launchdarkly.sdk.server.interfaces.LDClientInterface; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Reports graph-level events for a single invocation of an {@link AgentGraphDefinition}. + *

+ * An {@code AIGraphTracker} is obtained from an enabled graph definition via + * {@link AgentGraphDefinition#createTracker()}, or reconstructed from a resumption token via + * {@link LDAIClient#createGraphTracker(String, LDContext)}. + *

+ * Graph-level methods (invocation, duration, tokens, path) are at-most-once: a second call on + * the same tracker is silently dropped. Edge-level methods (redirect, handoff) are multi-fire — + * each call records a distinct event. + *

+ * Implementations are thread-safe. + */ +public final class AIGraphTracker { + + private static final String GRAPH_INVOCATION_SUCCESS = "$ld:ai:graph:invocation_success"; + private static final String GRAPH_INVOCATION_FAILURE = "$ld:ai:graph:invocation_failure"; + private static final String GRAPH_DURATION_TOTAL = "$ld:ai:graph:duration:total"; + private static final String GRAPH_TOTAL_TOKENS = "$ld:ai:graph:total_tokens"; + private static final String GRAPH_PATH = "$ld:ai:graph:path"; + private static final String GRAPH_REDIRECT = "$ld:ai:graph:redirect"; + private static final String GRAPH_HANDOFF_SUCCESS = "$ld:ai:graph:handoff_success"; + private static final String GRAPH_HANDOFF_FAILURE = "$ld:ai:graph:handoff_failure"; + + private final LDClientInterface client; + private final LDContext context; + private final LDLogger logger; + + private final String runId; + private final String graphKey; + private final String variationKey; + private final int version; + + private final String resumptionToken; + + // At-most-once guards: null = not yet recorded, non-null = recorded. + // trackInvocationSuccess and trackInvocationFailure share invocationRecorded: + // true = success was recorded, false = failure was recorded. + private final AtomicReference invocationRecorded = new AtomicReference<>(); + private final AtomicReference durationRecorded = new AtomicReference<>(); + private final AtomicReference tokensRecorded = new AtomicReference<>(); + private final AtomicReference> pathRecorded = new AtomicReference<>(); + + AIGraphTracker( + LDClientInterface client, + String runId, + String graphKey, + String variationKey, + int version, + LDContext context, + LDLogger logger) { + this.client = Objects.requireNonNull(client, "client"); + this.runId = Objects.requireNonNull(runId, "runId"); + this.graphKey = Objects.requireNonNull(graphKey, "graphKey"); + this.variationKey = variationKey; + this.version = version; + this.context = Objects.requireNonNull(context, "context"); + this.logger = Objects.requireNonNull(logger, "logger"); + + this.resumptionToken = ResumptionTokens.encodeGraph(runId, graphKey, variationKey, version); + } + + /** + * Reconstructs a graph tracker from a resumption token, preserving the original run identity. + * + * @param token the resumption token produced by {@link #getResumptionToken()} + * @param client the LaunchDarkly client; must not be {@code null} + * @param context the evaluation context; must not be {@code null} + * @return a new tracker with the decoded run identity + * @throws IllegalArgumentException if the token is malformed + */ + public static AIGraphTracker fromResumptionToken( + String token, LDClientInterface client, LDContext context) { + return fromResumptionToken(token, client, context, defaultLogger()); + } + + /** + * Reconstructs a graph tracker from a resumption token, preserving the original run identity, + * and logging through the supplied logger. + * + * @param token the resumption token produced by {@link #getResumptionToken()} + * @param client the LaunchDarkly client; must not be {@code null} + * @param context the evaluation context; must not be {@code null} + * @param logger the logger to use for at-most-once warnings; must not be {@code null} + * @return a new tracker with the decoded run identity + * @throws IllegalArgumentException if the token is malformed + */ + public static AIGraphTracker fromResumptionToken( + String token, LDClientInterface client, LDContext context, LDLogger logger) { + ResumptionTokens.DecodedGraph d = ResumptionTokens.decodeGraph(token); + int version = Math.max(1, d.getVersion()); + return new AIGraphTracker( + client, + d.getRunId(), + d.getGraphKey(), + d.getVariationKey(), + version, + context, + logger); + } + + /** + * Records that the graph invocation succeeded. + *

+ * At-most-once and mutually exclusive with {@link #trackInvocationFailure()}: whichever is + * called first wins. + */ + public void trackInvocationSuccess() { + if (!invocationRecorded.compareAndSet(null, Boolean.TRUE)) { + logger.warn("Skipping trackInvocationSuccess: invocation already recorded on this graph tracker."); + return; + } + client.trackMetric(GRAPH_INVOCATION_SUCCESS, context, baseData().build(), 1); + } + + /** + * Records that the graph invocation failed. + *

+ * At-most-once and mutually exclusive with {@link #trackInvocationSuccess()}: whichever is + * called first wins. + */ + public void trackInvocationFailure() { + if (!invocationRecorded.compareAndSet(null, Boolean.FALSE)) { + logger.warn("Skipping trackInvocationFailure: invocation already recorded on this graph tracker."); + return; + } + client.trackMetric(GRAPH_INVOCATION_FAILURE, context, baseData().build(), 1); + } + + /** + * Records the total wall-clock duration of the graph invocation. + *

+ * At-most-once: subsequent calls on the same tracker are silently dropped. + * + * @param durationMs the duration in milliseconds + */ + public void trackDuration(double durationMs) { + if (!durationRecorded.compareAndSet(null, durationMs)) { + logger.warn("Skipping trackDuration: duration already recorded on this graph tracker."); + return; + } + client.trackMetric(GRAPH_DURATION_TOTAL, context, baseData().build(), durationMs); + } + + /** + * Records the total token usage for the graph invocation. + *

+ * At-most-once: subsequent calls are silently dropped. Calls where all counts are zero do not + * consume the at-most-once slot. + * + * @param tokens the token usage; ignored if {@code null} + */ + public void trackTotalTokens(TokenUsage tokens) { + if (tokens == null) { + logger.debug("Skipping trackTotalTokens: tokens was null."); + return; + } + boolean hasPositive = tokens.getTotal() > 0 || tokens.getInput() > 0 || tokens.getOutput() > 0; + if (!hasPositive) { + return; + } + if (!tokensRecorded.compareAndSet(null, tokens)) { + logger.warn("Skipping trackTotalTokens: token usage already recorded on this graph tracker."); + return; + } + if (tokens.getTotal() > 0) { + client.trackMetric(GRAPH_TOTAL_TOKENS, context, baseData().build(), tokens.getTotal()); + } + } + + /** + * Records the ordered path of node keys visited during the graph invocation. + *

+ * At-most-once: subsequent calls on the same tracker are silently dropped. + * + * @param path the ordered list of node keys; ignored if {@code null} or empty + */ + public void trackPath(List path) { + if (path == null || path.isEmpty()) { + logger.debug("Skipping trackPath: path was null or empty."); + return; + } + List snapshot = Collections.unmodifiableList(new ArrayList<>(path)); + if (!pathRecorded.compareAndSet(null, snapshot)) { + logger.warn("Skipping trackPath: path already recorded on this graph tracker."); + return; + } + ArrayBuilder ab = LDValue.buildArray(); + for (String s : path) { + ab.add(LDValue.of(s)); + } + LDValue data = baseData().put("path", ab.build()).build(); + client.trackMetric(GRAPH_PATH, context, data, 1); + } + + /** + * Records a redirect event, where the graph transitioned from one node to a different target + * than the edge originally specified. + *

+ * Multi-fire: every call emits an event. + * + * @param sourceKey the key of the source node + * @param redirectedTarget the key of the node that was actually used + */ + public void trackRedirect(String sourceKey, String redirectedTarget) { + LDValue data = baseData() + .put("sourceKey", sourceKey) + .put("redirectedTarget", redirectedTarget) + .build(); + client.trackMetric(GRAPH_REDIRECT, context, data, 1); + } + + /** + * Records a successful handoff from one node to another. + *

+ * Multi-fire: every call emits an event. + * + * @param sourceKey the key of the source node + * @param targetKey the key of the target node + */ + public void trackHandoffSuccess(String sourceKey, String targetKey) { + LDValue data = baseData() + .put("sourceKey", sourceKey) + .put("targetKey", targetKey) + .build(); + client.trackMetric(GRAPH_HANDOFF_SUCCESS, context, data, 1); + } + + /** + * Records a failed handoff from one node to another. + *

+ * Multi-fire: every call emits an event. + * + * @param sourceKey the key of the source node + * @param targetKey the key of the target node + */ + public void trackHandoffFailure(String sourceKey, String targetKey) { + LDValue data = baseData() + .put("sourceKey", sourceKey) + .put("targetKey", targetKey) + .build(); + client.trackMetric(GRAPH_HANDOFF_FAILURE, context, data, 1); + } + + /** + * Returns a snapshot of all graph-level metrics tracked so far on this tracker. + * + * @return the metric summary; never {@code null} + */ + public AIGraphMetricSummary getSummary() { + return new AIGraphMetricSummary( + invocationRecorded.get(), + durationRecorded.get(), + tokensRecorded.get(), + pathRecorded.get(), + resumptionToken); + } + + /** + * Returns the resumption token for this graph run. + *

+ * The token encodes the run identity and can be passed to + * {@link LDAIClient#createGraphTracker(String, LDContext)} to reconstruct the tracker across + * requests. + * + * @return the resumption token; never {@code null} + */ + public String getResumptionToken() { + return resumptionToken; + } + + private ObjectBuilder baseData() { + ObjectBuilder b = LDValue.buildObject() + .put("runId", runId) + .put("graphKey", graphKey) + .put("version", version); + if (variationKey != null) { + b.put("variationKey", variationKey); + } + return b; + } + + private static LDLogger defaultLogger() { + LDLogAdapter adapter; + try { + Class.forName("org.slf4j.LoggerFactory"); + adapter = LDSLF4J.adapter(); + } catch (ClassNotFoundException e) { + adapter = Logs.toConsole(); + } + return LDLogger.withAdapter(adapter, "LaunchDarkly.AI"); + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AgentGraphDefinition.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AgentGraphDefinition.java new file mode 100644 index 00000000..b18b68a2 --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AgentGraphDefinition.java @@ -0,0 +1,289 @@ +package com.launchdarkly.sdk.server.ai; + +import com.launchdarkly.sdk.server.ai.internal.AgentGraphFlagValue; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +/** + * The fully resolved definition of an agent graph, containing all nodes and their edges. + *

+ * An {@code AgentGraphDefinition} is obtained from {@link LDAIClient#agentGraph}. When + * {@link #isEnabled()} returns {@code false}, the graph definition was not fetchable or failed + * validation; in that case all node collections are empty and traversal methods are no-ops. + *

+ * Traversal methods ({@link #traverse} and {@link #reverseTraverse}) are BFS-based and + * cycle-safe: each node is visited at most once. + *

+ * This class is thread-safe. All returned collections are unmodifiable. + */ +public final class AgentGraphDefinition { + private final AgentGraphFlagValue flagValue; + private final Map nodes; + private final boolean enabled; + private final Supplier trackerFactory; + + AgentGraphDefinition( + AgentGraphFlagValue flagValue, + Map nodes, + boolean enabled, + Supplier trackerFactory) { + this.flagValue = flagValue; + this.nodes = nodes; + this.enabled = enabled; + this.trackerFactory = trackerFactory; + } + + /** + * Returns {@code true} if this graph definition is enabled and all nodes were successfully + * fetched. + * + * @return whether the graph is enabled + */ + public boolean isEnabled() { + return enabled; + } + + /** + * Returns the root node of the graph (the entry point). + * + * @return the root node, or {@code null} if the graph is disabled or the root key is not in the + * node map + */ + public AgentGraphNode rootNode() { + return nodes.get(flagValue.getRoot()); + } + + /** + * Returns the node with the given key, or {@code null} if not found. + * + * @param nodeKey the node key to look up + * @return the node, or {@code null} + */ + public AgentGraphNode getNode(String nodeKey) { + return nodes.get(nodeKey); + } + + /** + * Returns the immediate child nodes of the node with the given key, following its outgoing + * edges. + * + * @param nodeKey the source node key + * @return an unmodifiable list of child nodes; empty if the node is terminal or not found + */ + public List getChildNodes(String nodeKey) { + AgentGraphNode node = nodes.get(nodeKey); + if (node == null) { + return Collections.emptyList(); + } + List children = new ArrayList<>(); + for (GraphEdge edge : node.getEdges()) { + AgentGraphNode child = nodes.get(edge.getKey()); + if (child != null) { + children.add(child); + } + } + return Collections.unmodifiableList(children); + } + + /** + * Returns all nodes that have an outgoing edge pointing to the given node key. + * + * @param nodeKey the target node key + * @return an unmodifiable list of parent nodes; empty if none found + */ + public List getParentNodes(String nodeKey) { + List parents = new ArrayList<>(); + for (AgentGraphNode node : nodes.values()) { + for (GraphEdge edge : node.getEdges()) { + if (nodeKey.equals(edge.getKey())) { + parents.add(node); + break; + } + } + } + return Collections.unmodifiableList(parents); + } + + /** + * Returns all terminal nodes (nodes with no outgoing edges). + * + * @return an unmodifiable list of terminal nodes; empty if the graph is disabled + */ + public List terminalNodes() { + List terminals = new ArrayList<>(); + for (AgentGraphNode node : nodes.values()) { + if (node.isTerminal()) { + terminals.add(node); + } + } + return Collections.unmodifiableList(terminals); + } + + /** + * Returns the internal parsed flag value for this graph. This is an internal type and is not + * part of the supported public API. + * + * @return the parsed flag value + */ + AgentGraphFlagValue getConfig() { + return flagValue; + } + + /** + * Creates a new {@link AIGraphTracker} for this graph invocation. + *

+ * Each call produces a fresh tracker with a new run ID. Returns {@code null} if the graph is + * disabled. + * + * @return a new tracker, or {@code null} if disabled + */ + public AIGraphTracker createTracker() { + if (!enabled || trackerFactory == null) { + return null; + } + return trackerFactory.get(); + } + + /** + * Performs a BFS traversal of the graph starting from the root node. + *

+ * For each node visited, {@code fn} is called with the node and the mutable context map. The + * return value of {@code fn} is stored in the context map under the node's key, making it + * available to subsequently visited nodes. Each node is visited exactly once (cycle-safe). + *

+ * This is a no-op when the graph is disabled or the root node is absent. + * + * @param fn the visitor function; receives the current node and the context map + * @param ctx the mutable context map; values from earlier nodes are available to later ones + */ + public void traverse(BiFunction, Object> fn, + Map ctx) { + AgentGraphNode root = rootNode(); + if (root == null) { + return; + } + + Set visited = new HashSet<>(); + Queue queue = new LinkedList<>(); + visited.add(root.getKey()); + queue.add(root); + + while (!queue.isEmpty()) { + AgentGraphNode node = queue.poll(); + Object result = fn.apply(node, ctx); + ctx.put(node.getKey(), result); + + for (AgentGraphNode child : getChildNodes(node.getKey())) { + if (visited.add(child.getKey())) { + queue.add(child); + } + } + } + } + + /** + * Performs a reverse BFS traversal of the graph, starting from terminal nodes and working + * upward toward the root. + *

+ * The root node is always processed last. Each node is visited exactly once (cycle-safe). This + * is a no-op when the graph is disabled or there are no terminal nodes. + * + * @param fn the visitor function; receives the current node and the context map + * @param ctx the mutable context map; values from earlier nodes are available to later ones + */ + public void reverseTraverse(BiFunction, Object> fn, + Map ctx) { + AgentGraphNode root = rootNode(); + if (root == null) { + return; + } + + Set visited = new HashSet<>(); + Queue queue = new LinkedList<>(); + + // Seed from terminals, excluding root (it will be processed last). + for (AgentGraphNode terminal : terminalNodes()) { + if (!terminal.getKey().equals(root.getKey()) && visited.add(terminal.getKey())) { + queue.add(terminal); + } + } + + while (!queue.isEmpty()) { + AgentGraphNode node = queue.poll(); + Object result = fn.apply(node, ctx); + ctx.put(node.getKey(), result); + + for (AgentGraphNode parent : getParentNodes(node.getKey())) { + if (!parent.getKey().equals(root.getKey()) && visited.add(parent.getKey())) { + queue.add(parent); + } + } + } + + // Process root last (whether or not it was encountered as a parent above). + if (visited.add(root.getKey())) { + Object result = fn.apply(root, ctx); + ctx.put(root.getKey(), result); + } + } + + /** + * Builds the node map from the parsed flag value and pre-fetched agent configs. + *

+ * For each key in {@link #collectAllKeys}, looks up the agent config from {@code configs} and + * the outgoing edges from the flag value's edge map. Returns an unmodifiable map. + * + * @param flagValue the parsed flag value + * @param configs the pre-fetched agent configs keyed by node key + * @return an unmodifiable map of nodes keyed by config key + */ + static Map buildNodes( + AgentGraphFlagValue flagValue, Map configs) { + Set allKeys = collectAllKeys(flagValue); + Map result = new HashMap<>(); + for (String key : allKeys) { + AIAgentConfig config = configs.get(key); + if (config == null) { + continue; + } + List edges = flagValue.getEdges().get(key); + if (edges == null) { + edges = Collections.emptyList(); + } + result.put(key, new AgentGraphNode(key, config, edges)); + } + return Collections.unmodifiableMap(result); + } + + /** + * Collects all unique node keys referenced anywhere in the flag value: the root key, all edge + * source keys, and all edge target keys. + * + * @param flagValue the parsed flag value + * @return the set of all unique node keys + */ + static Set collectAllKeys(AgentGraphFlagValue flagValue) { + Set keys = new HashSet<>(); + String root = flagValue.getRoot(); + if (root != null && !root.isEmpty()) { + keys.add(root); + } + for (Map.Entry> entry : flagValue.getEdges().entrySet()) { + keys.add(entry.getKey()); + for (GraphEdge edge : entry.getValue()) { + keys.add(edge.getKey()); + } + } + return keys; + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AgentGraphNode.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AgentGraphNode.java new file mode 100644 index 00000000..bea5cc05 --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AgentGraphNode.java @@ -0,0 +1,59 @@ +package com.launchdarkly.sdk.server.ai; + +import java.util.Collections; +import java.util.List; + +/** + * A node in an {@link AgentGraphDefinition}, wrapping a single agent config and its outgoing + * edges to other nodes. + *

+ * Nodes are retrieved from a graph definition via {@link AgentGraphDefinition#getNode(String)}, + * {@link AgentGraphDefinition#rootNode()}, etc. Instances are immutable. + */ +public final class AgentGraphNode { + private final String key; + private final AIAgentConfig config; + private final List edges; + + AgentGraphNode(String key, AIAgentConfig config, List edges) { + this.key = key; + this.config = config; + this.edges = edges == null ? Collections.emptyList() : edges; + } + + /** + * Returns the AI Config key identifying this node. + * + * @return the node key; never {@code null} + */ + public String getKey() { + return key; + } + + /** + * Returns the retrieved agent config for this node. + * + * @return the agent config; never {@code null} + */ + public AIAgentConfig getConfig() { + return config; + } + + /** + * Returns the outgoing edges from this node to other nodes. + * + * @return an unmodifiable list of outgoing edges; never {@code null} but may be empty + */ + public List getEdges() { + return edges; + } + + /** + * Returns {@code true} if this node has no outgoing edges. + * + * @return {@code true} if terminal (no edges), {@code false} otherwise + */ + public boolean isTerminal() { + return edges.isEmpty(); + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/GraphEdge.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/GraphEdge.java new file mode 100644 index 00000000..116d3abb --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/GraphEdge.java @@ -0,0 +1,46 @@ +package com.launchdarkly.sdk.server.ai; + +import com.launchdarkly.sdk.LDValue; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * An edge in an agent graph, representing a directed connection from one node to a target node. + *

+ * Each edge carries the key of the target {@link AgentGraphNode} and an optional handoff map of + * arbitrary data that may be passed when transitioning to the target node. Instances are immutable. + */ +public final class GraphEdge { + private final String key; + private final Map handoff; + + public GraphEdge(String key, Map handoff) { + this.key = key; + this.handoff = handoff != null + ? Collections.unmodifiableMap(new LinkedHashMap<>(handoff)) + : null; + } + + /** + * Returns the key of the target node that this edge points to. + * + * @return the target node key; never {@code null} + */ + public String getKey() { + return key; + } + + /** + * Returns the handoff options for this edge. + *

+ * The handoff is an optional map of arbitrary values that may be passed when transitioning + * to the target node. If no handoff was defined for this edge, returns {@code null}. + * + * @return an unmodifiable map of handoff values, or {@code null} if absent + */ + public Map getHandoff() { + return handoff; + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java index 77031cf4..d0734ba0 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java @@ -2,6 +2,7 @@ import com.launchdarkly.sdk.LDContext; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -100,4 +101,45 @@ AIJudgeConfig judgeConfig( * @throws IllegalArgumentException if the token is malformed */ LDAIConfigTracker createTracker(String resumptionToken, LDContext context); + + /** + * Fetches and validates an agent graph definition identified by {@code graphKey}. + *

+ * Evaluates the graph flag, fetches all referenced node configs, and validates the graph + * structure. If validation fails (disabled flag, empty root, unreachable nodes, or any + * non-enabled child config) the returned definition has {@link AgentGraphDefinition#isEnabled()} + * {@code == false} and an empty node map. + *

+ * Also emits a {@code $ld:ai:usage:agent-graph} usage event. + * + * @param graphKey the flag key identifying the agent graph + * @param context the evaluation context + * @param variables Mustache template variables applied to each node's instructions + * @return the resolved graph definition; never {@code null} + */ + AgentGraphDefinition agentGraph(String graphKey, LDContext context, Map variables); + + /** + * Fetches and validates an agent graph definition with no template variables. + * + * @param graphKey the flag key identifying the agent graph + * @param context the evaluation context + * @return the resolved graph definition; never {@code null} + */ + default AgentGraphDefinition agentGraph(String graphKey, LDContext context) { + return agentGraph(graphKey, context, Collections.emptyMap()); + } + + /** + * Reconstructs an {@link AIGraphTracker} from a resumption token. + *

+ * Use this to continue tracking a graph run across requests by passing the token produced by + * {@link AIGraphTracker#getResumptionToken()}. + * + * @param resumptionToken the token produced by a prior {@link AIGraphTracker} + * @param context the evaluation context + * @return a reconstructed tracker with the original run identity + * @throws IllegalArgumentException if the token is malformed + */ + AIGraphTracker createGraphTracker(String resumptionToken, LDContext context); } diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java index 8bf81e71..02581650 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java @@ -10,6 +10,7 @@ import com.launchdarkly.sdk.LDValueType; import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Message; import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Mode; +import com.launchdarkly.sdk.server.ai.internal.AgentGraphFlagValue; import com.launchdarkly.sdk.server.ai.internal.AIConfigFlagValue; import com.launchdarkly.sdk.server.ai.internal.AIConfigParser; import com.launchdarkly.sdk.server.ai.internal.AISdkInfo; @@ -18,10 +19,16 @@ import com.launchdarkly.sdk.server.interfaces.LDClientInterface; import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashMap; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Queue; +import java.util.Set; import java.util.UUID; import java.util.function.Supplier; @@ -45,6 +52,7 @@ public final class LDAIClientImpl implements LDAIClient { private static final String TRACK_USAGE_AGENT_CONFIG = "$ld:ai:usage:agent-config"; private static final String TRACK_USAGE_AGENT_CONFIGS = "$ld:ai:usage:agent-configs"; private static final String TRACK_USAGE_JUDGE_CONFIG = "$ld:ai:usage:judge-config"; + private static final String TRACK_USAGE_AGENT_GRAPH = "$ld:ai:usage:agent-graph"; private static final LDContext INIT_TRACK_CONTEXT = LDContext .builder("ld-internal-tracking") @@ -110,21 +118,26 @@ public AIAgentConfig agentConfig( @Override public Map agentConfigs( List agentConfigs, LDContext context) { - Map result = new LinkedHashMap<>(); int count = 0; if (agentConfigs != null) { for (AIAgentConfigRequest request : agentConfigs) { - if (request == null) { - continue; + if (request != null) { + count++; } - count++; - result.put( - request.getKey(), - evaluateAgent(request.getKey(), context, request.getDefaultValue(), request.getVariables())); } } client.trackMetric(TRACK_USAGE_AGENT_CONFIGS, context, LDValue.of(count), count); + Map result = new LinkedHashMap<>(); + if (agentConfigs != null) { + for (AIAgentConfigRequest request : agentConfigs) { + if (request != null) { + result.put( + request.getKey(), + evaluateAgent(request.getKey(), context, request.getDefaultValue(), request.getVariables())); + } + } + } return result; } @@ -142,9 +155,15 @@ public AIJudgeConfig judgeConfig( private AIAgentConfig evaluateAgent( String key, LDContext context, AIAgentConfigDefault defaultValue, Map variables) { + return evaluateAgent(key, context, defaultValue, variables, null); + } + + private AIAgentConfig evaluateAgent( + String key, LDContext context, AIAgentConfigDefault defaultValue, + Map variables, String graphKey) { AIAgentConfigDefault effectiveDefault = defaultValue != null ? defaultValue : AIAgentConfigDefault.disabled(); - return (AIAgentConfig) evaluate(key, context, effectiveDefault, Mode.AGENT, variables); + return (AIAgentConfig) evaluate(key, context, effectiveDefault, Mode.AGENT, variables, graphKey); } /** @@ -158,13 +177,23 @@ private AIConfig evaluate( AIConfigDefault defaultValue, Mode mode, Map variables) { + return evaluate(key, context, defaultValue, mode, variables, null); + } + + private AIConfig evaluate( + String key, + LDContext context, + AIConfigDefault defaultValue, + Mode mode, + Map variables, + String graphKey) { LDValue value = client.jsonValueVariation(key, context, LDValue.ofNull()); // A valid AI Config variation is always a JSON object (it carries the _ldMeta block). When the // flag is absent or cannot be evaluated the base SDK hands back our null sentinel; in that case // we return the caller's typed default directly rather than serializing it and parsing it back. if (value == null || value.getType() != LDValueType.OBJECT) { - return buildConfigFromDefault(key, mode, defaultValue, context, variables); + return buildConfigFromDefault(key, mode, defaultValue, context, variables, graphKey); } AIConfigFlagValue parsed = AIConfigParser.parse(value); @@ -174,10 +203,10 @@ private AIConfig evaluate( logger.warn( "AI Config mode mismatch for {}: expected {}, got {}. Returning default config.", key, mode.getWireValue(), flagMode.getWireValue()); - return buildConfigFromDefault(key, mode, defaultValue, context, variables); + return buildConfigFromDefault(key, mode, defaultValue, context, variables, graphKey); } - return buildConfig(key, mode, parsed, context, variables); + return buildConfig(key, mode, parsed, context, variables, graphKey); } private AIConfig buildConfig( @@ -186,9 +215,19 @@ private AIConfig buildConfig( AIConfigFlagValue parsed, LDContext context, Map variables) { + return buildConfig(key, mode, parsed, context, variables, null); + } + + private AIConfig buildConfig( + String key, + Mode mode, + AIConfigFlagValue parsed, + LDContext context, + Map variables, + String graphKey) { Supplier factory = trackerFactory( key, parsed.getVariationKey(), parsed.getVersion(), - parsed.getModel(), parsed.getProvider(), context); + parsed.getModel(), parsed.getProvider(), context, graphKey); switch (mode) { case AGENT: return new AIAgentConfig( @@ -233,9 +272,19 @@ private AIConfig buildConfigFromDefault( AIConfigDefault defaultValue, LDContext context, Map variables) { + return buildConfigFromDefault(key, mode, defaultValue, context, variables, null); + } + + private AIConfig buildConfigFromDefault( + String key, + Mode mode, + AIConfigDefault defaultValue, + LDContext context, + Map variables, + String graphKey) { // Default configs still get real trackers — the configKey was requested even if no flag was found. // variationKey is null because no flag evaluation occurred. - Supplier factory = trackerFactory(key, null, null, null, null, context); + Supplier factory = trackerFactory(key, null, null, null, null, context, graphKey); switch (mode) { case AGENT: { AIAgentConfigDefault agent = (AIAgentConfigDefault) defaultValue; @@ -287,6 +336,17 @@ private Supplier trackerFactory( com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Model model, com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Provider provider, LDContext context) { + return trackerFactory(configKey, variationKey, version, model, provider, context, null); + } + + private Supplier trackerFactory( + String configKey, + String variationKey, + Integer version, + com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Model model, + com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Provider provider, + LDContext context, + String graphKey) { String modelName = model != null && model.getName() != null ? model.getName() : ""; String providerName = provider != null && provider.getName() != null ? provider.getName() : ""; int ver = version != null ? version : 1; @@ -299,8 +359,92 @@ private Supplier trackerFactory( modelName, providerName, context, - null, // graphKey — set by agentGraph() in Plan 3 + graphKey, + logger); + } + + @Override + public AgentGraphDefinition agentGraph( + String graphKey, LDContext context, Map variables) { + client.trackMetric(TRACK_USAGE_AGENT_GRAPH, context, LDValue.of(graphKey), 1); + + LDValue flagValue = client.jsonValueVariation(graphKey, context, LDValue.ofNull()); + AgentGraphFlagValue parsed = AgentGraphFlagValue.parse( + (flagValue != null && flagValue.getType() == LDValueType.OBJECT) ? flagValue : null); + + Map effectiveVars = variables != null ? variables : Collections.emptyMap(); + Supplier trackerFactory = () -> new AIGraphTracker( + client, + UUID.randomUUID().toString(), + graphKey, + parsed.getVariationKey(), + parsed.getVersion(), + context, logger); + + AgentGraphDefinition disabled = new AgentGraphDefinition( + parsed, Collections.emptyMap(), false, trackerFactory); + + // Validation step 1: _ldMeta.enabled + if (!parsed.isEnabled()) { + return disabled; + } + + // Validation step 2: root must be non-empty + String root = parsed.getRoot(); + if (root == null || root.isEmpty()) { + return disabled; + } + + // Validation step 3: all keys reachable from root (no unconnected nodes) + Set allKeys = AgentGraphDefinition.collectAllKeys(parsed); + Set reachableKeys = collectReachableKeys(parsed); + for (String key : allKeys) { + if (!reachableKeys.contains(key)) { + return disabled; + } + } + + // Validation step 4: fetch each child config (without emitting usage events) + Map configs = new HashMap<>(); + for (String key : allKeys) { + AIAgentConfig config = evaluateAgent(key, context, null, effectiveVars, graphKey); + if (!config.isEnabled()) { + return disabled; + } + configs.put(key, config); + } + + Map nodes = AgentGraphDefinition.buildNodes(parsed, configs); + return new AgentGraphDefinition(parsed, nodes, true, trackerFactory); + } + + private static Set collectReachableKeys(AgentGraphFlagValue flagValue) { + Set visited = new HashSet<>(); + Queue queue = new LinkedList<>(); + String root = flagValue.getRoot(); + if (root == null || root.isEmpty()) { + return visited; + } + visited.add(root); + queue.add(root); + while (!queue.isEmpty()) { + String key = queue.poll(); + List edges = flagValue.getEdges().get(key); + if (edges != null) { + for (GraphEdge edge : edges) { + if (visited.add(edge.getKey())) { + queue.add(edge.getKey()); + } + } + } + } + return visited; + } + + @Override + public AIGraphTracker createGraphTracker(String resumptionToken, LDContext context) { + return AIGraphTracker.fromResumptionToken(resumptionToken, client, context, logger); } @Override diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/AgentGraphFlagValue.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/AgentGraphFlagValue.java new file mode 100644 index 00000000..335983d1 --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/AgentGraphFlagValue.java @@ -0,0 +1,220 @@ +package com.launchdarkly.sdk.server.ai.internal; + +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.LDValueType; +import com.launchdarkly.sdk.server.ai.GraphEdge; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * The parsed, strongly-typed representation of an agent graph flag variation's JSON protocol. + *

+ * Mirrors the wire structure: the {@code _ldMeta} block (enabled / variationKey / version) plus + * the {@code root} config key and the {@code edges} adjacency map. Produced by {@link #parse} and + * consumed when assembling {@link com.launchdarkly.sdk.server.ai.AgentGraphDefinition}. + *

+ * Parsing is intentionally defensive: malformed, missing, or wrong-typed fields never raise an + * exception. When {@code _ldMeta.enabled} is absent the default is {@code true}; when + * {@code _ldMeta.version} is absent the default is {@code 1}. + *

+ * This class is an internal implementation detail and is not part of the supported API. + */ +public final class AgentGraphFlagValue { + private static final int DEFAULT_VERSION = 1; + + private final String root; + private final Map> edges; + private final String variationKey; + private final int version; + private final boolean enabled; + + private AgentGraphFlagValue(Builder b) { + this.root = b.root; + this.edges = b.edges == null + ? Collections.>emptyMap() + : Collections.unmodifiableMap(b.edges); + this.variationKey = b.variationKey; + this.version = b.version; + this.enabled = b.enabled; + } + + /** + * Returns a disabled flag value with empty root and no edges. Used when the raw flag value is + * not a JSON object or when validation fails. + * + * @return a disabled {@link AgentGraphFlagValue} + */ + public static AgentGraphFlagValue disabled() { + return new Builder().enabled(false).build(); + } + + /** + * Parses a raw {@link LDValue} flag variation into a strongly-typed {@link AgentGraphFlagValue}. + *

+ * Returns {@link #disabled()} when {@code value} is not a JSON object. + * + * @param value the raw flag value; may be {@code null} or any JSON type + * @return the parsed representation; never {@code null} + */ + public static AgentGraphFlagValue parse(LDValue value) { + if (value == null || value.getType() != LDValueType.OBJECT) { + return disabled(); + } + + Builder builder = new Builder(); // defaults: enabled=true, version=1 + + LDValue meta = value.get("_ldMeta"); + if (meta != null && meta.getType() == LDValueType.OBJECT) { + LDValue enabledVal = meta.get("enabled"); + if (enabledVal.getType() == LDValueType.BOOLEAN) { + builder.enabled(enabledVal.booleanValue()); + } + LDValue variationKeyVal = meta.get("variationKey"); + if (variationKeyVal.getType() == LDValueType.STRING) { + builder.variationKey(variationKeyVal.stringValue()); + } + LDValue versionVal = meta.get("version"); + if (versionVal.getType() == LDValueType.NUMBER) { + builder.version(versionVal.intValue()); + } + } + + LDValue rootVal = value.get("root"); + if (rootVal.getType() == LDValueType.STRING) { + builder.root(rootVal.stringValue()); + } + + LDValue edgesVal = value.get("edges"); + if (edgesVal.getType() == LDValueType.OBJECT) { + Map> edges = new LinkedHashMap<>(); + for (String sourceKey : edgesVal.keys()) { + LDValue edgeArray = edgesVal.get(sourceKey); + if (edgeArray.getType() != LDValueType.ARRAY) { + continue; + } + List edgeList = new ArrayList<>(); + for (LDValue edgeObj : edgeArray.values()) { + if (edgeObj == null || edgeObj.getType() != LDValueType.OBJECT) { + continue; + } + LDValue keyVal = edgeObj.get("key"); + if (keyVal.getType() != LDValueType.STRING) { + continue; + } + String targetKey = keyVal.stringValue(); + Map handoff = parseHandoff(edgeObj.get("handoff")); + edgeList.add(new GraphEdge(targetKey, handoff)); + } + edges.put(sourceKey, Collections.unmodifiableList(edgeList)); + } + builder.edges(edges); + } + + return builder.build(); + } + + private static Map parseHandoff(LDValue handoff) { + if (handoff == null || handoff.getType() != LDValueType.OBJECT) { + return null; + } + Map result = new LinkedHashMap<>(); + for (String key : handoff.keys()) { + result.put(key, handoff.get(key)); + } + return result.isEmpty() ? null : Collections.unmodifiableMap(result); + } + + /** + * Returns the root node's AI Config key. + * + * @return the root key; never {@code null}, but may be empty when not specified + */ + public String getRoot() { + return root; + } + + /** + * Returns the adjacency map of outgoing edges keyed by source node config key. + * + * @return an unmodifiable map; never {@code null} but may be empty + */ + public Map> getEdges() { + return edges; + } + + /** + * Returns the {@code _ldMeta.variationKey}. + * + * @return the variation key, or {@code null} if absent + */ + public String getVariationKey() { + return variationKey; + } + + /** + * Returns the {@code _ldMeta.version}, defaulting to 1 when absent. + * + * @return the version; always >= 1 + */ + public int getVersion() { + return version; + } + + /** + * Returns {@code true} if the graph is enabled. Defaults to {@code true} when + * {@code _ldMeta.enabled} is absent. + * + * @return whether the graph is enabled + */ + public boolean isEnabled() { + return enabled; + } + + static Builder builder() { + return new Builder(); + } + + static final class Builder { + private String root = ""; + private Map> edges; + private String variationKey; + private int version = DEFAULT_VERSION; + private boolean enabled = true; + + private Builder() { + } + + Builder root(String v) { + this.root = v; + return this; + } + + Builder edges(Map> v) { + this.edges = v; + return this; + } + + Builder variationKey(String v) { + this.variationKey = v; + return this; + } + + Builder version(int v) { + this.version = v; + return this; + } + + Builder enabled(boolean v) { + this.enabled = v; + return this; + } + + AgentGraphFlagValue build() { + return new AgentGraphFlagValue(this); + } + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java index ed15c16a..dc10817c 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java @@ -13,7 +13,7 @@ *

* This class is an internal implementation detail and is not part of the supported API. */ -final class ResumptionTokens { +public final class ResumptionTokens { private static final Base64.Encoder ENCODER = Base64.getUrlEncoder().withoutPadding(); private static final Base64.Decoder DECODER = Base64.getUrlDecoder(); @@ -146,10 +146,10 @@ private static Decoded parseJson(String json) { } } - if (runId == null || runId.isEmpty()) { + if (runId == null || runId.trim().isEmpty()) { throw new IllegalArgumentException("Resumption token missing required field 'runId'"); } - if (configKey == null || configKey.isEmpty()) { + if (configKey == null || configKey.trim().isEmpty()) { throw new IllegalArgumentException("Resumption token missing required field 'configKey'"); } if (version == null) { @@ -249,6 +249,163 @@ static String escapeJson(String s) { return sb.toString(); } + /** + * Encodes a graph-level resumption token from the given graph identity fields. + *

+ * Field order in the JSON: {@code runId}, {@code graphKey}, {@code variationKey} (omitted if + * {@code null}), {@code version}. + * + * @param runId the run ID + * @param graphKey the agent graph key + * @param variationKey the variation key, or {@code null} to omit + * @param version the graph version + * @return the URL-safe Base64-encoded token + */ + public static String encodeGraph(String runId, String graphKey, String variationKey, int version) { + StringBuilder sb = new StringBuilder(); + sb.append("{\"runId\":\"").append(escapeJson(runId)).append('"'); + sb.append(",\"graphKey\":\"").append(escapeJson(graphKey)).append('"'); + if (variationKey != null) { + sb.append(",\"variationKey\":\"").append(escapeJson(variationKey)).append('"'); + } + sb.append(",\"version\":").append(version); + sb.append('}'); + return ENCODER.encodeToString(sb.toString().getBytes(StandardCharsets.UTF_8)); + } + + /** + * Decodes a graph resumption token previously produced by {@link #encodeGraph}. + * + * @param token the URL-safe Base64 token + * @return the decoded fields + * @throws IllegalArgumentException if the token is malformed, oversized, or missing required fields + */ + public static DecodedGraph decodeGraph(String token) { + if (token == null) { + throw new IllegalArgumentException("Graph resumption token must not be null"); + } + String json; + try { + byte[] bytes = DECODER.decode(token); + json = new String(bytes, StandardCharsets.UTF_8); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Graph resumption token is not valid Base64: " + e.getMessage(), e); + } + + return parseGraphJson(json); + } + + private static DecodedGraph parseGraphJson(String json) { + json = json.trim(); + if (!json.startsWith("{") || !json.endsWith("}")) { + throw new IllegalArgumentException("Graph resumption token JSON must be an object"); + } + + String runId = null; + String graphKey = null; + String variationKey = null; + Integer version = null; + + int pos = 1; + while (pos < json.length() - 1) { + pos = skipWhitespace(json, pos); + if (pos >= json.length() - 1) { + break; + } + if (json.charAt(pos) == ',') { + pos++; + pos = skipWhitespace(json, pos); + } + if (pos >= json.length() - 1) { + break; + } + + if (json.charAt(pos) != '"') { + throw new IllegalArgumentException("Expected '\"' at position " + pos + " in graph resumption token"); + } + int[] keyEnd = new int[1]; + String key = readString(json, pos, keyEnd); + pos = keyEnd[0]; + + pos = skipWhitespace(json, pos); + if (pos >= json.length() || json.charAt(pos) != ':') { + throw new IllegalArgumentException("Expected ':' after key in graph resumption token"); + } + pos++; + pos = skipWhitespace(json, pos); + + if (json.charAt(pos) == '"') { + int[] valEnd = new int[1]; + String value = readString(json, pos, valEnd); + pos = valEnd[0]; + switch (key) { + case "runId": runId = value; break; + case "graphKey": graphKey = value; break; + case "variationKey": variationKey = value; break; + default: break; + } + } else { + int start = pos; + while (pos < json.length() && json.charAt(pos) != ',' && json.charAt(pos) != '}') { + pos++; + } + String numStr = json.substring(start, pos).trim(); + if ("version".equals(key)) { + try { + version = Integer.parseInt(numStr); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Field 'version' must be an integer in graph resumption token", e); + } + } + } + } + + if (runId == null || runId.trim().isEmpty()) { + throw new IllegalArgumentException("Graph resumption token missing required field 'runId'"); + } + if (graphKey == null || graphKey.trim().isEmpty()) { + throw new IllegalArgumentException("Graph resumption token missing required field 'graphKey'"); + } + if (version == null) { + throw new IllegalArgumentException("Graph resumption token missing required field 'version'"); + } + + return new DecodedGraph(runId, graphKey, variationKey, version); + } + + /** + * The decoded fields from a graph resumption token. + */ + public static final class DecodedGraph { + private final String runId; + private final String graphKey; + private final String variationKey; + private final int version; + + public DecodedGraph(String runId, String graphKey, String variationKey, int version) { + this.runId = runId; + this.graphKey = graphKey; + this.variationKey = variationKey; + this.version = version; + } + + public String getRunId() { + return runId; + } + + public String getGraphKey() { + return graphKey; + } + + public String getVariationKey() { + return variationKey; + } + + public int getVersion() { + return version; + } + } + /** * The decoded fields from a resumption token. */ diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java new file mode 100644 index 00000000..a35a4119 --- /dev/null +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java @@ -0,0 +1,505 @@ +package com.launchdarkly.sdk.server.ai; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyDouble; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import com.launchdarkly.logging.LDLogLevel; +import com.launchdarkly.logging.LDLogger; +import com.launchdarkly.logging.LogCapture; +import com.launchdarkly.logging.Logs; +import com.launchdarkly.sdk.LDContext; +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; +import com.launchdarkly.sdk.server.interfaces.LDClientInterface; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +@SuppressWarnings("javadoc") +public class AIGraphTrackerTest { + private LDClientInterface client; + private LogCapture logCapture; + private LDLogger logger; + private AIGraphTracker tracker; + + private static final LDContext CONTEXT = LDContext.create("user-key"); + private static final String RUN_ID = "test-run-id"; + private static final String GRAPH_KEY = "my-graph"; + private static final String VARIATION_KEY = "var-abc"; + private static final int VERSION = 2; + + @Before + public void setUp() { + client = mock(LDClientInterface.class); + logCapture = Logs.capture(); + logger = LDLogger.withAdapter(logCapture, "test"); + tracker = makeTracker(VARIATION_KEY); + } + + private AIGraphTracker makeTracker(String variationKey) { + return new AIGraphTracker(client, RUN_ID, GRAPH_KEY, variationKey, VERSION, CONTEXT, logger); + } + + private List warnings() { + return logCapture.getMessages().stream() + .filter(m -> m.getLevel() == LDLogLevel.WARN) + .map(LogCapture.Message::getText) + .collect(Collectors.toList()); + } + + private List debugs() { + return logCapture.getMessages().stream() + .filter(m -> m.getLevel() == LDLogLevel.DEBUG) + .map(LogCapture.Message::getText) + .collect(Collectors.toList()); + } + + private LDValue baseExpectedData() { + return LDValue.buildObject() + .put("runId", RUN_ID) + .put("graphKey", GRAPH_KEY) + .put("variationKey", VARIATION_KEY) + .put("version", VERSION) + .build(); + } + + // ---- trackInvocationSuccess ----------------------------------------------- + + @Test + public void trackInvocationSuccessEmitsCorrectEvent() { + tracker.trackInvocationSuccess(); + verify(client).trackMetric( + eq("$ld:ai:graph:invocation_success"), eq(CONTEXT), eq(baseExpectedData()), eq(1.0)); + } + + @Test + public void trackInvocationSuccessIsAtMostOnce() { + tracker.trackInvocationSuccess(); + tracker.trackInvocationSuccess(); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:invocation_success"), any(), any(), anyDouble()); + assertThat(warnings().stream().anyMatch(w -> w.contains("invocation already recorded")), is(true)); + } + + // ---- trackInvocationFailure ----------------------------------------------- + + @Test + public void trackInvocationFailureEmitsCorrectEvent() { + tracker.trackInvocationFailure(); + verify(client).trackMetric( + eq("$ld:ai:graph:invocation_failure"), eq(CONTEXT), eq(baseExpectedData()), eq(1.0)); + } + + @Test + public void trackInvocationFailureIsAtMostOnce() { + tracker.trackInvocationFailure(); + tracker.trackInvocationFailure(); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:invocation_failure"), any(), any(), anyDouble()); + } + + @Test + public void successAndFailureShareGuard_successFirst() { + tracker.trackInvocationSuccess(); + tracker.trackInvocationFailure(); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:invocation_success"), any(), any(), anyDouble()); + verify(client, never()).trackMetric( + eq("$ld:ai:graph:invocation_failure"), any(), any(), anyDouble()); + assertThat(warnings().stream().anyMatch(w -> w.contains("invocation already recorded")), is(true)); + } + + @Test + public void successAndFailureShareGuard_failureFirst() { + tracker.trackInvocationFailure(); + tracker.trackInvocationSuccess(); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:invocation_failure"), any(), any(), anyDouble()); + verify(client, never()).trackMetric( + eq("$ld:ai:graph:invocation_success"), any(), any(), anyDouble()); + } + + // ---- trackDuration -------------------------------------------------------- + + @Test + public void trackDurationEmitsCorrectEvent() { + tracker.trackDuration(250.0); + verify(client).trackMetric( + eq("$ld:ai:graph:duration:total"), eq(CONTEXT), eq(baseExpectedData()), eq(250.0)); + } + + @Test + public void trackDurationIsAtMostOnce() { + tracker.trackDuration(100.0); + tracker.trackDuration(200.0); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:duration:total"), any(), any(), anyDouble()); + assertThat(warnings().stream().anyMatch(w -> w.contains("duration already recorded")), is(true)); + } + + // ---- trackTotalTokens ----------------------------------------------------- + + @Test + public void trackTotalTokensEmitsCorrectEvent() { + TokenUsage tokens = new TokenUsage(30, 20, 10); + tracker.trackTotalTokens(tokens); + verify(client).trackMetric( + eq("$ld:ai:graph:total_tokens"), eq(CONTEXT), eq(baseExpectedData()), eq(30.0)); + } + + @Test + public void trackTotalTokensIsAtMostOnce() { + tracker.trackTotalTokens(new TokenUsage(10, 5, 5)); + tracker.trackTotalTokens(new TokenUsage(20, 10, 10)); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:total_tokens"), any(), any(), anyDouble()); + assertThat(warnings().stream().anyMatch(w -> w.contains("token usage already recorded")), is(true)); + } + + @Test + public void trackTotalTokensAllZeroDoesNotBurnSlot() { + tracker.trackTotalTokens(new TokenUsage(0, 0, 0)); + verify(client, never()).trackMetric( + eq("$ld:ai:graph:total_tokens"), any(), any(), anyDouble()); + // Slot not consumed — a subsequent non-zero call should fire + tracker.trackTotalTokens(new TokenUsage(5, 5, 0)); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:total_tokens"), any(), any(), anyDouble()); + } + + @Test + public void trackTotalTokensNullIsIgnored() { + tracker.trackTotalTokens(null); + verify(client, never()).trackMetric( + eq("$ld:ai:graph:total_tokens"), any(), any(), anyDouble()); + assertThat(debugs().stream().anyMatch(w -> w.contains("tokens was null")), is(true)); + } + + // ---- trackPath ------------------------------------------------------------ + + @Test + public void trackPathEmitsCorrectEvent() { + List path = Arrays.asList("node-a", "node-b", "node-c"); + tracker.trackPath(path); + + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric( + eq("$ld:ai:graph:path"), eq(CONTEXT), dataCaptor.capture(), eq(1.0)); + + LDValue data = dataCaptor.getValue(); + assertThat(data.get("graphKey").stringValue(), is(GRAPH_KEY)); + assertThat(data.get("path").size(), is(3)); + assertThat(data.get("path").get(0).stringValue(), is("node-a")); + assertThat(data.get("path").get(1).stringValue(), is("node-b")); + assertThat(data.get("path").get(2).stringValue(), is("node-c")); + } + + @Test + public void trackPathIsAtMostOnce() { + tracker.trackPath(Arrays.asList("node-a", "node-b")); + tracker.trackPath(Arrays.asList("node-c")); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:path"), any(), any(), anyDouble()); + assertThat(warnings().stream().anyMatch(w -> w.contains("path already recorded")), is(true)); + } + + @Test + public void trackPathNullOrEmptyIsIgnored() { + tracker.trackPath(null); + tracker.trackPath(Arrays.asList()); + verify(client, never()).trackMetric( + eq("$ld:ai:graph:path"), any(), any(), anyDouble()); + // Slot not consumed — a valid path should still fire + tracker.trackPath(Arrays.asList("node-a")); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:path"), any(), any(), anyDouble()); + } + + // ---- trackRedirect -------------------------------------------------------- + + @Test + public void trackRedirectEmitsCorrectEvent() { + tracker.trackRedirect("source-a", "target-b"); + + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric( + eq("$ld:ai:graph:redirect"), eq(CONTEXT), dataCaptor.capture(), eq(1.0)); + + LDValue data = dataCaptor.getValue(); + assertThat(data.get("sourceKey").stringValue(), is("source-a")); + assertThat(data.get("redirectedTarget").stringValue(), is("target-b")); + assertThat(data.get("graphKey").stringValue(), is(GRAPH_KEY)); + } + + @Test + public void trackRedirectIsMultiFire() { + tracker.trackRedirect("source-a", "target-b"); + tracker.trackRedirect("source-a", "target-c"); + verify(client, times(2)).trackMetric( + eq("$ld:ai:graph:redirect"), any(), any(), anyDouble()); + } + + // ---- trackHandoffSuccess -------------------------------------------------- + + @Test + public void trackHandoffSuccessEmitsCorrectEvent() { + tracker.trackHandoffSuccess("source-a", "target-b"); + + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric( + eq("$ld:ai:graph:handoff_success"), eq(CONTEXT), dataCaptor.capture(), eq(1.0)); + + LDValue data = dataCaptor.getValue(); + assertThat(data.get("sourceKey").stringValue(), is("source-a")); + assertThat(data.get("targetKey").stringValue(), is("target-b")); + } + + @Test + public void trackHandoffSuccessIsMultiFire() { + tracker.trackHandoffSuccess("source-a", "target-b"); + tracker.trackHandoffSuccess("source-b", "target-c"); + verify(client, times(2)).trackMetric( + eq("$ld:ai:graph:handoff_success"), any(), any(), anyDouble()); + } + + // ---- trackHandoffFailure -------------------------------------------------- + + @Test + public void trackHandoffFailureEmitsCorrectEvent() { + tracker.trackHandoffFailure("source-a", "target-b"); + + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric( + eq("$ld:ai:graph:handoff_failure"), eq(CONTEXT), dataCaptor.capture(), eq(1.0)); + + LDValue data = dataCaptor.getValue(); + assertThat(data.get("sourceKey").stringValue(), is("source-a")); + assertThat(data.get("targetKey").stringValue(), is("target-b")); + } + + @Test + public void trackHandoffFailureIsMultiFire() { + tracker.trackHandoffFailure("source-a", "target-b"); + tracker.trackHandoffFailure("source-c", "target-d"); + verify(client, times(2)).trackMetric( + eq("$ld:ai:graph:handoff_failure"), any(), any(), anyDouble()); + } + + // ---- getSummary ----------------------------------------------------------- + + @Test + public void getSummaryNullWhenNothingTracked() { + AIGraphMetricSummary summary = tracker.getSummary(); + assertThat(summary.getSuccess(), is(nullValue())); + assertThat(summary.getDurationMs(), is(nullValue())); + assertThat(summary.getTokens(), is(nullValue())); + assertThat(summary.getPath(), is(nullValue())); + assertThat(summary.getResumptionToken(), is(notNullValue())); + } + + @Test + public void getSummaryReflectsTrackedValues() { + tracker.trackInvocationSuccess(); + tracker.trackDuration(150.0); + tracker.trackTotalTokens(new TokenUsage(40, 25, 15)); + tracker.trackPath(Arrays.asList("a", "b", "c")); + + AIGraphMetricSummary summary = tracker.getSummary(); + assertThat(summary.getSuccess(), is(Boolean.TRUE)); + assertThat(summary.getDurationMs(), is(150.0)); + assertThat(summary.getTokens().getTotal(), is(40L)); + assertThat(summary.getPath(), is(Arrays.asList("a", "b", "c"))); + assertThat(summary.getResumptionToken(), is(tracker.getResumptionToken())); + } + + @Test + public void getSummarySuccessIsFalseWhenFailureTracked() { + tracker.trackInvocationFailure(); + assertThat(tracker.getSummary().getSuccess(), is(Boolean.FALSE)); + } + + // ---- variationKey in track data ------------------------------------------ + + @Test + public void variationKeyIncludedWhenPresent() { + tracker.trackInvocationSuccess(); + ArgumentCaptor captor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric(eq("$ld:ai:graph:invocation_success"), any(), captor.capture(), anyDouble()); + assertThat(captor.getValue().get("variationKey").stringValue(), is(VARIATION_KEY)); + } + + @Test + public void variationKeyOmittedWhenNull() { + AIGraphTracker t = makeTracker(null); + t.trackInvocationSuccess(); + ArgumentCaptor captor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric(eq("$ld:ai:graph:invocation_success"), any(), captor.capture(), anyDouble()); + assertThat(captor.getValue().get("variationKey").isNull(), is(true)); + } + + // ---- resumption token ---------------------------------------------------- + + @Test + public void getResumptionTokenIsNotNull() { + assertThat(tracker.getResumptionToken(), is(notNullValue())); + } + + @Test + public void fromResumptionTokenRoundTrips() { + String token = tracker.getResumptionToken(); + AIGraphTracker reconstructed = AIGraphTracker.fromResumptionToken(token, client, CONTEXT); + assertThat(reconstructed.getResumptionToken(), is(token)); + } + + @Test + public void fromResumptionTokenPreservesRunId() { + String token = tracker.getResumptionToken(); + AIGraphTracker reconstructed = AIGraphTracker.fromResumptionToken(token, client, CONTEXT); + // Verify same events are emitted by the reconstructed tracker + reconstructed.trackInvocationSuccess(); + ArgumentCaptor captor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric(eq("$ld:ai:graph:invocation_success"), any(), captor.capture(), anyDouble()); + assertThat(captor.getValue().get("runId").stringValue(), is(RUN_ID)); + assertThat(captor.getValue().get("graphKey").stringValue(), is(GRAPH_KEY)); + } + + @Test + public void fromResumptionTokenWithLoggerRoutesWarningsThroughIt() { + LogCapture captureForResumed = Logs.capture(); + LDLogger resumedLogger = LDLogger.withAdapter(captureForResumed, "test"); + String token = tracker.getResumptionToken(); + AIGraphTracker reconstructed = AIGraphTracker.fromResumptionToken(token, client, CONTEXT, resumedLogger); + + reconstructed.trackInvocationSuccess(); + reconstructed.trackInvocationSuccess(); // duplicate — should warn on resumedLogger, not the base logCapture + + List resumedWarnings = captureForResumed.getMessages().stream() + .filter(m -> m.getLevel() == LDLogLevel.WARN) + .map(LogCapture.Message::getText) + .collect(Collectors.toList()); + assertThat(resumedWarnings.stream().anyMatch(w -> w.contains("invocation already recorded")), is(true)); + assertThat(logCapture.getMessages(), is(org.hamcrest.Matchers.empty())); + } + + @Test + public void fromResumptionTokenClampsVersionLessThanOne() { + AIGraphTracker t = new AIGraphTracker(client, RUN_ID, GRAPH_KEY, null, 0, CONTEXT, logger); + // Version 0 → token contains 0, but fromResumptionToken should clamp to 1 + // Actually: the tracker stores version as-is, but fromResumptionToken clamps + // Encode a token with version = 0 manually: + String token = com.launchdarkly.sdk.server.ai.internal.ResumptionTokens.encodeGraph( + RUN_ID, GRAPH_KEY, null, 0); + AIGraphTracker reconstructed = AIGraphTracker.fromResumptionToken(token, client, CONTEXT); + reconstructed.trackInvocationSuccess(); + ArgumentCaptor captor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric(eq("$ld:ai:graph:invocation_success"), any(), captor.capture(), anyDouble()); + assertThat(captor.getValue().get("version").intValue(), is(1)); + } + + // ---- constructor null checks --------------------------------------------- + + @Test(expected = NullPointerException.class) + public void constructorRejectsNullClient() { + new AIGraphTracker(null, RUN_ID, GRAPH_KEY, VARIATION_KEY, VERSION, CONTEXT, logger); + } + + @Test(expected = NullPointerException.class) + public void constructorRejectsNullRunId() { + new AIGraphTracker(client, null, GRAPH_KEY, VARIATION_KEY, VERSION, CONTEXT, logger); + } + + @Test(expected = NullPointerException.class) + public void constructorRejectsNullGraphKey() { + new AIGraphTracker(client, RUN_ID, null, VARIATION_KEY, VERSION, CONTEXT, logger); + } + + @Test(expected = NullPointerException.class) + public void constructorRejectsNullContext() { + new AIGraphTracker(client, RUN_ID, GRAPH_KEY, VARIATION_KEY, VERSION, null, logger); + } + + // ---- concurrency: at-most-once under contention ------------------------- + + @Test + public void trackInvocationAtMostOnceUnderConcurrency() throws InterruptedException { + int threads = 20; + CountDownLatch ready = new CountDownLatch(threads); + CountDownLatch go = new CountDownLatch(1); + ExecutorService exec = Executors.newFixedThreadPool(threads); + + for (int i = 0; i < threads; i++) { + final int idx = i; + exec.submit(() -> { + ready.countDown(); + try { go.await(); } catch (InterruptedException ignored) {} + if (idx % 2 == 0) { + tracker.trackInvocationSuccess(); + } else { + tracker.trackInvocationFailure(); + } + }); + } + + ready.await(); + go.countDown(); + exec.shutdown(); + exec.awaitTermination(5, TimeUnit.SECONDS); + + // Exactly one invocation event fires total + int successCount = 0, failureCount = 0; + try { + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:invocation_success"), any(), any(), anyDouble()); + successCount = 1; + } catch (AssertionError ignored) {} + try { + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:invocation_failure"), any(), any(), anyDouble()); + failureCount = 1; + } catch (AssertionError ignored) {} + assertThat(successCount + failureCount, is(1)); + } + + @Test + public void trackDurationAtMostOnceUnderConcurrency() throws InterruptedException { + int threads = 20; + CountDownLatch ready = new CountDownLatch(threads); + CountDownLatch go = new CountDownLatch(1); + ExecutorService exec = Executors.newFixedThreadPool(threads); + + for (int i = 0; i < threads; i++) { + exec.submit(() -> { + ready.countDown(); + try { go.await(); } catch (InterruptedException ignored) {} + tracker.trackDuration(100.0); + }); + } + + ready.await(); + go.countDown(); + exec.shutdown(); + exec.awaitTermination(5, TimeUnit.SECONDS); + + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:duration:total"), any(), any(), anyDouble()); + } +} diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AgentGraphDefinitionTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AgentGraphDefinitionTest.java new file mode 100644 index 00000000..938f198c --- /dev/null +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AgentGraphDefinitionTest.java @@ -0,0 +1,445 @@ +package com.launchdarkly.sdk.server.ai; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.mock; + +import com.launchdarkly.sdk.ArrayBuilder; +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.ObjectBuilder; +import com.launchdarkly.sdk.server.ai.internal.AgentGraphFlagValue; +import com.launchdarkly.sdk.server.interfaces.LDClientInterface; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.BiFunction; + +import org.junit.Test; + +@SuppressWarnings("javadoc") +public class AgentGraphDefinitionTest { + + // ---- helpers -------------------------------------------------------------- + + private static AgentGraphFlagValue flagValue(String root, String[][] edges) { + ObjectBuilder edgesObj = LDValue.buildObject(); + if (edges != null) { + Map> adj = new LinkedHashMap<>(); + for (String[] edge : edges) { + if (!adj.containsKey(edge[0])) { + adj.put(edge[0], new ArrayList<>()); + } + adj.get(edge[0]).add(edge[1]); + } + for (Map.Entry> entry : adj.entrySet()) { + ArrayBuilder arr = LDValue.buildArray(); + for (String target : entry.getValue()) { + arr.add(LDValue.buildObject().put("key", target).build()); + } + edgesObj.put(entry.getKey(), arr.build()); + } + } + LDValue value = LDValue.buildObject() + .put("root", root) + .put("edges", edgesObj.build()) + .put("_ldMeta", LDValue.buildObject() + .put("enabled", true) + .put("version", 1) + .build()) + .build(); + return AgentGraphFlagValue.parse(value); + } + + private static AIAgentConfig makeConfig(String key, boolean enabled) { + return new AIAgentConfig(key, enabled, null, null, null, null, null, + () -> mock(com.launchdarkly.sdk.server.ai.LDAIConfigTracker.class)); + } + + private static Map configs(String... keys) { + Map m = new HashMap<>(); + for (String key : keys) { + m.put(key, makeConfig(key, true)); + } + return m; + } + + private AgentGraphDefinition buildEnabled(String root, String[][] edges, String... nodeKeys) { + AgentGraphFlagValue fv = flagValue(root, edges); + Map nodes = AgentGraphDefinition.buildNodes(fv, configs(nodeKeys)); + return new AgentGraphDefinition(fv, nodes, true, null); + } + + // ---- collectAllKeys ------------------------------------------------------- + + @Test + public void collectAllKeysIncludesRoot() { + AgentGraphFlagValue fv = flagValue("root-node", null); + Set keys = AgentGraphDefinition.collectAllKeys(fv); + assertThat(keys.contains("root-node"), is(true)); + } + + @Test + public void collectAllKeysIncludesEdgeSourcesAndTargets() { + AgentGraphFlagValue fv = flagValue("a", new String[][]{{"a", "b"}, {"b", "c"}}); + Set keys = AgentGraphDefinition.collectAllKeys(fv); + assertThat(keys, containsInAnyOrder("a", "b", "c")); + } + + @Test + public void collectAllKeysWithNoEdges() { + AgentGraphFlagValue fv = flagValue("solo", null); + Set keys = AgentGraphDefinition.collectAllKeys(fv); + assertThat(keys, containsInAnyOrder("solo")); + } + + @Test + public void collectAllKeysEmptyRootIsExcluded() { + AgentGraphFlagValue fv = AgentGraphFlagValue.disabled(); + Set keys = AgentGraphDefinition.collectAllKeys(fv); + assertThat(keys, is(empty())); + } + + // ---- buildNodes ----------------------------------------------------------- + + @Test + public void buildNodesCreatesCorrectNodeMap() { + AgentGraphFlagValue fv = flagValue("a", new String[][]{{"a", "b"}}); + Map nodes = AgentGraphDefinition.buildNodes(fv, configs("a", "b")); + assertThat(nodes.size(), is(2)); + assertThat(nodes.get("a").getKey(), is("a")); + assertThat(nodes.get("b").getKey(), is("b")); + } + + @Test + public void buildNodesAttachesEdgesToNodes() { + AgentGraphFlagValue fv = flagValue("a", new String[][]{{"a", "b"}, {"a", "c"}}); + Map nodes = AgentGraphDefinition.buildNodes(fv, configs("a", "b", "c")); + List edges = nodes.get("a").getEdges(); + assertThat(edges.size(), is(2)); + } + + @Test + public void buildNodesSkipsMissingConfigs() { + AgentGraphFlagValue fv = flagValue("a", new String[][]{{"a", "b"}}); + // Only provide config for "a", not "b" + Map nodes = AgentGraphDefinition.buildNodes(fv, configs("a")); + assertThat(nodes.size(), is(1)); + assertThat(nodes.containsKey("a"), is(true)); + assertThat(nodes.containsKey("b"), is(false)); + } + + // ---- rootNode / getNode -------------------------------------------------- + + @Test + public void rootNodeReturnsCorrectNode() { + AgentGraphDefinition graph = buildEnabled("a", new String[][]{{"a", "b"}}, "a", "b"); + AgentGraphNode root = graph.rootNode(); + assertThat(root, is(notNullValue())); + assertThat(root.getKey(), is("a")); + } + + @Test + public void rootNodeReturnsNullWhenDisabled() { + AgentGraphDefinition graph = new AgentGraphDefinition( + AgentGraphFlagValue.disabled(), Collections.emptyMap(), false, null); + assertThat(graph.rootNode(), is(nullValue())); + } + + @Test + public void getNodeReturnsCorrectNode() { + AgentGraphDefinition graph = buildEnabled("a", new String[][]{{"a", "b"}}, "a", "b"); + assertThat(graph.getNode("b").getKey(), is("b")); + } + + @Test + public void getNodeReturnsNullForUnknownKey() { + AgentGraphDefinition graph = buildEnabled("a", null, "a"); + assertThat(graph.getNode("not-here"), is(nullValue())); + } + + // ---- isTerminal ---------------------------------------------------------- + + @Test + public void terminalNodeHasNoEdges() { + AgentGraphDefinition graph = buildEnabled("a", new String[][]{{"a", "b"}}, "a", "b"); + assertThat(graph.getNode("b").isTerminal(), is(true)); + assertThat(graph.getNode("a").isTerminal(), is(false)); + } + + @Test + public void singleNodeGraphIsTerminal() { + AgentGraphDefinition graph = buildEnabled("a", null, "a"); + assertThat(graph.rootNode().isTerminal(), is(true)); + } + + // ---- getChildNodes ------------------------------------------------------- + + @Test + public void getChildNodesFollowsEdges() { + AgentGraphDefinition graph = buildEnabled("a", + new String[][]{{"a", "b"}, {"a", "c"}}, "a", "b", "c"); + List children = graph.getChildNodes("a"); + assertThat(children.size(), is(2)); + Set keys = new HashSet<>(); + for (AgentGraphNode n : children) keys.add(n.getKey()); + assertThat(keys, containsInAnyOrder("b", "c")); + } + + @Test + public void getChildNodesReturnsEmptyForTerminal() { + AgentGraphDefinition graph = buildEnabled("a", new String[][]{{"a", "b"}}, "a", "b"); + assertThat(graph.getChildNodes("b"), is(empty())); + } + + @Test + public void getChildNodesReturnsEmptyForUnknownKey() { + AgentGraphDefinition graph = buildEnabled("a", null, "a"); + assertThat(graph.getChildNodes("no-such-key"), is(empty())); + } + + // ---- getParentNodes ------------------------------------------------------ + + @Test + public void getParentNodesFindsDirectParents() { + AgentGraphDefinition graph = buildEnabled("a", + new String[][]{{"a", "c"}, {"b", "c"}}, "a", "b", "c"); + List parents = graph.getParentNodes("c"); + assertThat(parents.size(), is(2)); + Set keys = new HashSet<>(); + for (AgentGraphNode n : parents) keys.add(n.getKey()); + assertThat(keys, containsInAnyOrder("a", "b")); + } + + @Test + public void getParentNodesReturnsEmptyForRoot() { + AgentGraphDefinition graph = buildEnabled("a", new String[][]{{"a", "b"}}, "a", "b"); + assertThat(graph.getParentNodes("a"), is(empty())); + } + + // ---- terminalNodes ------------------------------------------------------- + + @Test + public void terminalNodesReturnsAllTerminals() { + // a -> b, a -> c; b and c are terminals + AgentGraphDefinition graph = buildEnabled("a", + new String[][]{{"a", "b"}, {"a", "c"}}, "a", "b", "c"); + List terminals = graph.terminalNodes(); + assertThat(terminals.size(), is(2)); + Set keys = new HashSet<>(); + for (AgentGraphNode n : terminals) keys.add(n.getKey()); + assertThat(keys, containsInAnyOrder("b", "c")); + } + + @Test + public void terminalNodesWithSingleNodeIncludesRoot() { + AgentGraphDefinition graph = buildEnabled("a", null, "a"); + assertThat(graph.terminalNodes().size(), is(1)); + assertThat(graph.terminalNodes().get(0).getKey(), is("a")); + } + + // ---- isEnabled ----------------------------------------------------------- + + @Test + public void isEnabledReflectsConstructorValue() { + AgentGraphDefinition enabled = buildEnabled("a", null, "a"); + assertThat(enabled.isEnabled(), is(true)); + + AgentGraphDefinition disabled = new AgentGraphDefinition( + AgentGraphFlagValue.disabled(), Collections.emptyMap(), false, null); + assertThat(disabled.isEnabled(), is(false)); + } + + // ---- createTracker ------------------------------------------------------- + + @Test + public void createTrackerReturnsNullWhenDisabled() { + AgentGraphDefinition graph = new AgentGraphDefinition( + AgentGraphFlagValue.disabled(), Collections.emptyMap(), false, null); + assertThat(graph.createTracker(), is(nullValue())); + } + + @Test + public void createTrackerReturnsTrackerWhenEnabled() { + LDClientInterface client = mock(LDClientInterface.class); + AgentGraphFlagValue fv = flagValue("a", null); + Map nodes = AgentGraphDefinition.buildNodes(fv, configs("a")); + AgentGraphDefinition graph = new AgentGraphDefinition(fv, nodes, true, + () -> new AIGraphTracker(client, "run-id", "graph-key", null, 1, + com.launchdarkly.sdk.LDContext.create("user"), + com.launchdarkly.logging.LDLogger.withAdapter( + com.launchdarkly.logging.Logs.none(), ""))); + assertThat(graph.createTracker(), is(notNullValue())); + } + + // ---- traverse ------------------------------------------------------------- + + @Test + public void traverseVisitsAllNodesFromRoot() { + // a -> b -> c + AgentGraphDefinition graph = buildEnabled("a", + new String[][]{{"a", "b"}, {"b", "c"}}, "a", "b", "c"); + + List visited = new ArrayList<>(); + BiFunction, Object> fn = (node, ctx) -> { + visited.add(node.getKey()); + return node.getKey(); + }; + graph.traverse(fn, new HashMap<>()); + assertThat(visited, containsInAnyOrder("a", "b", "c")); + // Root must be first + assertThat(visited.get(0), is("a")); + } + + @Test + public void traverseStoresResultsInContext() { + AgentGraphDefinition graph = buildEnabled("a", new String[][]{{"a", "b"}}, "a", "b"); + + Map ctx = new HashMap<>(); + BiFunction, Object> fn = (node, c) -> node.getKey() + "_result"; + graph.traverse(fn, ctx); + + assertThat(ctx.get("a"), is("a_result")); + assertThat(ctx.get("b"), is("b_result")); + } + + @Test + public void traverseIsNoOpWhenDisabled() { + AgentGraphDefinition graph = new AgentGraphDefinition( + AgentGraphFlagValue.disabled(), Collections.emptyMap(), false, null); + + List visited = new ArrayList<>(); + graph.traverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + assertThat(visited, is(empty())); + } + + @Test + public void traverseHandlesCyclesSafely() { + // Manually build a cyclic graph: a -> b -> a + LDClientInterface client = mock(LDClientInterface.class); + Map cfgs = configs("a", "b"); + List aEdges = Collections.singletonList(new GraphEdge("b", null)); + List bEdges = Collections.singletonList(new GraphEdge("a", null)); + Map nodes = new HashMap<>(); + nodes.put("a", new AgentGraphNode("a", cfgs.get("a"), aEdges)); + nodes.put("b", new AgentGraphNode("b", cfgs.get("b"), bEdges)); + nodes = Collections.unmodifiableMap(nodes); + + AgentGraphFlagValue fv = flagValue("a", new String[][]{{"a", "b"}, {"b", "a"}}); + AgentGraphDefinition graph = new AgentGraphDefinition(fv, nodes, true, null); + + List visited = new ArrayList<>(); + graph.traverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + assertThat(visited.size(), is(2)); // each node visited exactly once + } + + // ---- reverseTraverse ------------------------------------------------------ + + @Test + public void reverseTraverseProcessesRootLast() { + // a -> b -> c + AgentGraphDefinition graph = buildEnabled("a", + new String[][]{{"a", "b"}, {"b", "c"}}, "a", "b", "c"); + + List visited = new ArrayList<>(); + graph.reverseTraverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + + // c is terminal (seeded first), root "a" is last + assertThat(visited.get(visited.size() - 1), is("a")); + assertThat(visited.contains("b"), is(true)); + assertThat(visited.contains("c"), is(true)); + } + + @Test + public void reverseTraverseVisitsAllNodes() { + // a -> b, a -> c (c and b are terminals) + AgentGraphDefinition graph = buildEnabled("a", + new String[][]{{"a", "b"}, {"a", "c"}}, "a", "b", "c"); + + List visited = new ArrayList<>(); + graph.reverseTraverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + assertThat(visited, containsInAnyOrder("a", "b", "c")); + assertThat(visited.get(visited.size() - 1), is("a")); + } + + @Test + public void reverseTraverseSingleNodeGraph() { + AgentGraphDefinition graph = buildEnabled("a", null, "a"); + + List visited = new ArrayList<>(); + graph.reverseTraverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + assertThat(visited, containsInAnyOrder("a")); + } + + @Test + public void reverseTraverseHandlesCyclesSafely() { + Map cfgs = configs("a", "b"); + List aEdges = Collections.singletonList(new GraphEdge("b", null)); + List bEdges = Collections.singletonList(new GraphEdge("a", null)); + Map nodes = new HashMap<>(); + nodes.put("a", new AgentGraphNode("a", cfgs.get("a"), aEdges)); + nodes.put("b", new AgentGraphNode("b", cfgs.get("b"), bEdges)); + nodes = Collections.unmodifiableMap(nodes); + + AgentGraphFlagValue fv = flagValue("a", new String[][]{{"a", "b"}, {"b", "a"}}); + AgentGraphDefinition graph = new AgentGraphDefinition(fv, nodes, true, null); + + List visited = new ArrayList<>(); + graph.reverseTraverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + // No infinite loop; in a pure cycle neither node is terminal, so no seeds are added — + // only root is processed in the final "root last" block. + assertThat(visited.size() <= 2, is(true)); + assertThat(visited.size() >= 1, is(true)); + } + + @Test + public void reverseTraverseIsNoOpWhenDisabled() { + AgentGraphDefinition graph = new AgentGraphDefinition( + AgentGraphFlagValue.disabled(), Collections.emptyMap(), false, null); + + List visited = new ArrayList<>(); + graph.reverseTraverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + assertThat(visited, is(empty())); + } + + // ---- diamond graph traversal ---------------------------------------------- + + @Test + public void traverseDiamondGraph() { + // root -> a, root -> b; a -> sink, b -> sink + AgentGraphDefinition graph = buildEnabled("root", + new String[][]{{"root", "a"}, {"root", "b"}, {"a", "sink"}, {"b", "sink"}}, + "root", "a", "b", "sink"); + + List visited = new ArrayList<>(); + graph.traverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + // root first, sink visited only once + assertThat(visited.get(0), is("root")); + assertThat(visited.size(), is(4)); + assertThat(new HashSet<>(visited).size(), is(4)); // all unique + } + + @Test + public void reverseTraverseDiamondGraph() { + AgentGraphDefinition graph = buildEnabled("root", + new String[][]{{"root", "a"}, {"root", "b"}, {"a", "sink"}, {"b", "sink"}}, + "root", "a", "b", "sink"); + + List visited = new ArrayList<>(); + graph.reverseTraverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + // root last, sink visited once + assertThat(visited.get(visited.size() - 1), is("root")); + assertThat(visited.size(), is(4)); + assertThat(new HashSet<>(visited).size(), is(4)); + } +} diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/LDAIClientImplTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/LDAIClientImplTest.java index 9adc857a..85c2b81c 100644 --- a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/LDAIClientImplTest.java +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/LDAIClientImplTest.java @@ -10,9 +10,11 @@ import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyDouble; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -20,8 +22,10 @@ import com.launchdarkly.logging.LDLogger; import com.launchdarkly.logging.LogCapture; import com.launchdarkly.logging.Logs; +import com.launchdarkly.sdk.ArrayBuilder; import com.launchdarkly.sdk.LDContext; import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.ObjectBuilder; import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Mode; import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Message; import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Model; @@ -30,6 +34,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -292,4 +297,191 @@ public void agentConfigsUsageCountExcludesNullEntries() { private static Map variables() { return new HashMap<>(); } + + // ---- agentGraph ----------------------------------------------------------- + + private static LDValue graphFlagValue(String root, boolean enabled, String variationKey, + String... edges) { + ObjectBuilder edgesObj = LDValue.buildObject(); + // edges are pairs: [source, target, source, target, ...] + Map edgeMap = new LinkedHashMap<>(); + for (int i = 0; i + 1 < edges.length; i += 2) { + String src = edges[i], tgt = edges[i + 1]; + if (!edgeMap.containsKey(src)) { + edgeMap.put(src, LDValue.buildArray()); + } + edgeMap.get(src).add(LDValue.buildObject().put("key", tgt).build()); + } + for (Map.Entry e : edgeMap.entrySet()) { + edgesObj.put(e.getKey(), e.getValue().build()); + } + ObjectBuilder meta = LDValue.buildObject() + .put("enabled", enabled) + .put("version", 1); + if (variationKey != null) { + meta.put("variationKey", variationKey); + } + return LDValue.buildObject() + .put("root", root) + .put("edges", edgesObj.build()) + .put("_ldMeta", meta.build()) + .build(); + } + + private static LDValue agentFlagValue(boolean enabled) { + return LDValue.parse("{\"_ldMeta\":{\"enabled\":" + enabled + ",\"mode\":\"agent\"}," + + "\"instructions\":\"test instructions\"}"); + } + + @Test + public void agentGraphFiresUsageEvent() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("node-a", true, "v1", "node-a", "node-b")); + when(client.jsonValueVariation(eq("node-a"), any(), any())).thenReturn(agentFlagValue(true)); + when(client.jsonValueVariation(eq("node-b"), any(), any())).thenReturn(agentFlagValue(true)); + + ai.agentGraph("g", context, null); + + verify(client).trackMetric( + eq("$ld:ai:usage:agent-graph"), eq(context), eq(LDValue.of("g")), eq(1.0)); + } + + @Test + public void agentGraphDoesNotFireNodeLevelUsageEvents() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("node-a", true, null)); + when(client.jsonValueVariation(eq("node-a"), any(), any())).thenReturn(agentFlagValue(true)); + + ai.agentGraph("g", context, null); + + verify(client, never()).trackMetric(eq("$ld:ai:usage:agent-config"), any(), any(), anyDouble()); + } + + @Test + public void agentGraphReturnsEnabledGraphForValidFlag() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("node-a", true, "v1", "node-a", "node-b")); + when(client.jsonValueVariation(eq("node-a"), any(), any())).thenReturn(agentFlagValue(true)); + when(client.jsonValueVariation(eq("node-b"), any(), any())).thenReturn(agentFlagValue(true)); + + AgentGraphDefinition graph = ai.agentGraph("g", context, null); + + assertThat(graph.isEnabled(), is(true)); + assertThat(graph.rootNode(), is(notNullValue())); + assertThat(graph.rootNode().getKey(), is("node-a")); + assertThat(graph.getNode("node-b"), is(notNullValue())); + } + + @Test + public void agentGraphReturnsDisabledWhenFlagDisabled() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("node-a", false, null)); + + AgentGraphDefinition graph = ai.agentGraph("g", context, null); + + assertThat(graph.isEnabled(), is(false)); + } + + @Test + public void agentGraphReturnsDisabledWhenRootMissing() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("", true, null)); + + AgentGraphDefinition graph = ai.agentGraph("g", context, null); + + assertThat(graph.isEnabled(), is(false)); + } + + @Test + public void agentGraphReturnsDisabledWhenUnreachableNode() { + // node-a -> node-b, but flag also declares node-c which is unreachable from root + LDValue flag = LDValue.buildObject() + .put("root", "node-a") + .put("edges", LDValue.buildObject() + .put("node-a", LDValue.buildArray() + .add(LDValue.buildObject().put("key", "node-b").build()) + .build()) + .put("node-c", LDValue.buildArray() // unreachable source + .add(LDValue.buildObject().put("key", "node-d").build()) + .build()) + .build()) + .put("_ldMeta", LDValue.buildObject().put("enabled", true).build()) + .build(); + when(client.jsonValueVariation(eq("g"), any(), any())).thenReturn(flag); + + AgentGraphDefinition graph = ai.agentGraph("g", context, null); + + assertThat(graph.isEnabled(), is(false)); + } + + @Test + public void agentGraphReturnsDisabledWhenAnyChildConfigDisabled() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("node-a", true, null, "node-a", "node-b")); + when(client.jsonValueVariation(eq("node-a"), any(), any())).thenReturn(agentFlagValue(true)); + when(client.jsonValueVariation(eq("node-b"), any(), any())).thenReturn(agentFlagValue(false)); + + AgentGraphDefinition graph = ai.agentGraph("g", context, null); + + assertThat(graph.isEnabled(), is(false)); + } + + @Test + public void agentGraphReturnsDisabledForNonObjectFlagValue() { + when(client.jsonValueVariation(eq("g"), any(), any())).thenReturn(LDValue.ofNull()); + + AgentGraphDefinition graph = ai.agentGraph("g", context, null); + + assertThat(graph.isEnabled(), is(false)); + } + + @Test + public void agentGraphNoVariablesOverloadCallsThreeArgVersion() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("node-a", true, null)); + when(client.jsonValueVariation(eq("node-a"), any(), any())).thenReturn(agentFlagValue(true)); + + AgentGraphDefinition graph = ai.agentGraph("g", context); + assertThat(graph.isEnabled(), is(true)); + } + + @Test + public void agentGraphChildConfigsIncludeGraphKey() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("node-a", true, "var-1")); + when(client.jsonValueVariation(eq("node-a"), any(), any())).thenReturn(agentFlagValue(true)); + + AgentGraphDefinition graph = ai.agentGraph("g", context, null); + assertThat(graph.isEnabled(), is(true)); + + // Verify: when a node tracker is created, graphKey is present in its track data + LDAIConfigTracker nodeTracker = graph.getNode("node-a").getConfig().createTracker(); + assertThat(nodeTracker.getTrackData().getGraphKey(), is("g")); + } + + // ---- createGraphTracker -------------------------------------------------- + + @Test + public void createGraphTrackerRoundTripsToken() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("node-a", true, "var-1")); + when(client.jsonValueVariation(eq("node-a"), any(), any())).thenReturn(agentFlagValue(true)); + + AgentGraphDefinition graph = ai.agentGraph("g", context, null); + AIGraphTracker original = graph.createTracker(); + assertThat(original, is(notNullValue())); + + String token = original.getResumptionToken(); + AIGraphTracker reconstructed = ai.createGraphTracker(token, context); + assertThat(reconstructed.getResumptionToken(), is(token)); + } + + @Test(expected = IllegalArgumentException.class) + public void createGraphTrackerThrowsForMalformedToken() { + ai.createGraphTracker("not-valid-base64!!!", context); + } + + private static java.util.LinkedHashMap newLinkedHashMap() { + return new java.util.LinkedHashMap<>(); + } } diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/AgentGraphFlagValueTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/AgentGraphFlagValueTest.java new file mode 100644 index 00000000..cc690c21 --- /dev/null +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/AgentGraphFlagValueTest.java @@ -0,0 +1,276 @@ +package com.launchdarkly.sdk.server.ai.internal; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; + +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.server.ai.GraphEdge; + +import java.util.List; +import java.util.Map; + +import org.junit.Test; + +@SuppressWarnings("javadoc") +public class AgentGraphFlagValueTest { + + // ---- disabled() factory --------------------------------------------------- + + @Test + public void disabledReturnsEnabledFalse() { + AgentGraphFlagValue v = AgentGraphFlagValue.disabled(); + assertThat(v.isEnabled(), is(false)); + assertThat(v.getRoot(), is("")); + assertThat(v.getEdges().isEmpty(), is(true)); + assertThat(v.getVariationKey(), is(nullValue())); + assertThat(v.getVersion(), is(1)); + } + + // ---- parse: non-object input ---------------------------------------------- + + @Test + public void parseNullReturnsDisabled() { + AgentGraphFlagValue v = AgentGraphFlagValue.parse(null); + assertThat(v.isEnabled(), is(false)); + } + + @Test + public void parseStringReturnsDisabled() { + AgentGraphFlagValue v = AgentGraphFlagValue.parse(LDValue.of("not-an-object")); + assertThat(v.isEnabled(), is(false)); + } + + @Test + public void parseArrayReturnsDisabled() { + AgentGraphFlagValue v = AgentGraphFlagValue.parse(LDValue.buildArray().add("x").build()); + assertThat(v.isEnabled(), is(false)); + } + + // ---- parse: defaults when fields absent ---------------------------------- + + @Test + public void parseEmptyObjectUsesDefaults() { + AgentGraphFlagValue v = AgentGraphFlagValue.parse(LDValue.buildObject().build()); + assertThat(v.isEnabled(), is(true)); // default true for graphs + assertThat(v.getVersion(), is(1)); // default 1 + assertThat(v.getRoot(), is("")); // empty string default + assertThat(v.getEdges().isEmpty(), is(true)); + assertThat(v.getVariationKey(), is(nullValue())); + } + + // ---- parse: _ldMeta fields ----------------------------------------------- + + @Test + public void parsesEnabledFalseFromMeta() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("_ldMeta", LDValue.buildObject() + .put("enabled", false) + .build()) + .build(); + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.isEnabled(), is(false)); + } + + @Test + public void parsesEnabledTrueFromMeta() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("_ldMeta", LDValue.buildObject() + .put("enabled", true) + .build()) + .build(); + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.isEnabled(), is(true)); + } + + @Test + public void parsesVariationKeyFromMeta() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("_ldMeta", LDValue.buildObject() + .put("variationKey", "var-xyz") + .build()) + .build(); + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.getVariationKey(), is("var-xyz")); + } + + @Test + public void parsesVersionFromMeta() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("_ldMeta", LDValue.buildObject() + .put("version", 5) + .build()) + .build(); + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.getVersion(), is(5)); + } + + @Test + public void missingMetaVersionDefaultsToOne() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("_ldMeta", LDValue.buildObject().build()) + .build(); + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.getVersion(), is(1)); + } + + @Test + public void nonObjectMetaIsIgnored() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("_ldMeta", LDValue.of("not-an-object")) + .build(); + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.isEnabled(), is(true)); // fallback to default + assertThat(v.getVersion(), is(1)); + } + + // ---- parse: root --------------------------------------------------------- + + @Test + public void parsesRootString() { + LDValue value = LDValue.buildObject() + .put("root", "entry-node") + .build(); + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.getRoot(), is("entry-node")); + } + + @Test + public void missingRootIsEmptyString() { + AgentGraphFlagValue v = AgentGraphFlagValue.parse(LDValue.buildObject().build()); + assertThat(v.getRoot(), is("")); + } + + @Test + public void nonStringRootIsIgnored() { + LDValue value = LDValue.buildObject() + .put("root", 42) + .build(); + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.getRoot(), is("")); + } + + // ---- parse: edges -------------------------------------------------------- + + @Test + public void parsesEdgesMap() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("edges", LDValue.buildObject() + .put("node-a", LDValue.buildArray() + .add(LDValue.buildObject().put("key", "node-b").build()) + .build()) + .build()) + .build(); + + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + Map> edges = v.getEdges(); + assertThat(edges.containsKey("node-a"), is(true)); + List aEdges = edges.get("node-a"); + assertThat(aEdges.size(), is(1)); + assertThat(aEdges.get(0).getKey(), is("node-b")); + assertThat(aEdges.get(0).getHandoff(), is(nullValue())); + } + + @Test + public void parsesEdgeWithHandoff() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("edges", LDValue.buildObject() + .put("node-a", LDValue.buildArray() + .add(LDValue.buildObject() + .put("key", "node-b") + .put("handoff", LDValue.buildObject() + .put("someData", LDValue.of("hello")) + .build()) + .build()) + .build()) + .build()) + .build(); + + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + GraphEdge edge = v.getEdges().get("node-a").get(0); + assertThat(edge.getKey(), is("node-b")); + assertThat(edge.getHandoff(), is(notNullValue())); + assertThat(edge.getHandoff().get("someData").stringValue(), is("hello")); + } + + @Test + public void edgeMissingKeyIsSkipped() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("edges", LDValue.buildObject() + .put("node-a", LDValue.buildArray() + .add(LDValue.buildObject().put("notKey", "node-b").build()) + .build()) + .build()) + .build(); + + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.getEdges().get("node-a"), is(empty())); + } + + @Test + public void edgeWithNonArrayValueIsSkipped() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("edges", LDValue.buildObject() + .put("node-a", LDValue.of("not-an-array")) + .build()) + .build(); + + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.getEdges().containsKey("node-a"), is(false)); + } + + @Test + public void nonObjectEdgesFieldIsIgnored() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("edges", LDValue.of("bad")) + .build(); + + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.getEdges().isEmpty(), is(true)); + } + + // ---- parse: full round-trip ---------------------------------------------- + + @Test + public void parsesFullFlagValue() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("edges", LDValue.buildObject() + .put("node-a", LDValue.buildArray() + .add(LDValue.buildObject().put("key", "node-b").build()) + .add(LDValue.buildObject().put("key", "node-c").build()) + .build()) + .put("node-b", LDValue.buildArray() + .add(LDValue.buildObject().put("key", "node-c").build()) + .build()) + .build()) + .put("_ldMeta", LDValue.buildObject() + .put("enabled", true) + .put("variationKey", "var-1") + .put("version", 2) + .build()) + .build(); + + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.isEnabled(), is(true)); + assertThat(v.getRoot(), is("node-a")); + assertThat(v.getVariationKey(), is("var-1")); + assertThat(v.getVersion(), is(2)); + assertThat(v.getEdges().get("node-a").size(), is(2)); + assertThat(v.getEdges().get("node-b").size(), is(1)); + } +}