diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java index 16ce751d..77031cf4 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java @@ -81,4 +81,23 @@ AIJudgeConfig judgeConfig( LDContext context, AIJudgeConfigDefault defaultValue, Map variables); + + /** + * Reconstructs a tracker from a resumption token, preserving the original run's identity. + *

+ * Use this when a multi-turn or streaming AI interaction spans multiple requests. The caller + * stores the resumption token from a previous tracker (via + * {@link LDAIConfigTracker#getResumptionToken()}) and passes it back here to continue tracking + * against the same run. + *

+ * Security note: resumption tokens embed flag-evaluation details such as the + * variation key and config version. Keep tokens server-side and do not round-trip them through + * untrusted clients where they could leak flag-targeting information. + * + * @param resumptionToken the token returned by a previous tracker; must not be {@code null} + * @param context the evaluation context for the new request; must not be {@code null} + * @return a tracker with the decoded run identity, never {@code null} + * @throws IllegalArgumentException if the token is malformed + */ + LDAIConfigTracker createTracker(String resumptionToken, LDContext context); } diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java index 650fdeed..8bf81e71 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java @@ -8,13 +8,13 @@ import com.launchdarkly.sdk.LDContext; import com.launchdarkly.sdk.LDValue; import com.launchdarkly.sdk.LDValueType; -import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Mode; import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Message; +import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Mode; import com.launchdarkly.sdk.server.ai.internal.AIConfigFlagValue; import com.launchdarkly.sdk.server.ai.internal.AIConfigParser; import com.launchdarkly.sdk.server.ai.internal.AISdkInfo; import com.launchdarkly.sdk.server.ai.internal.Interpolator; -import com.launchdarkly.sdk.server.ai.internal.NoOpAIConfigTracker; +import com.launchdarkly.sdk.server.ai.internal.LDAIConfigTrackerImpl; import com.launchdarkly.sdk.server.interfaces.LDClientInterface; import java.util.ArrayList; @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.UUID; import java.util.function.Supplier; /** @@ -51,8 +52,6 @@ public final class LDAIClientImpl implements LDAIClient { .anonymous(true) .build(); - // Tracking is implemented in a later step; until then every config hands out the no-op tracker. - private static final Supplier TRACKER_FACTORY = () -> NoOpAIConfigTracker.INSTANCE; private final LDClientInterface client; private final LDLogger logger; @@ -187,6 +186,9 @@ private AIConfig buildConfig( AIConfigFlagValue parsed, LDContext context, Map variables) { + Supplier factory = trackerFactory( + key, parsed.getVariationKey(), parsed.getVersion(), + parsed.getModel(), parsed.getProvider(), context); switch (mode) { case AGENT: return new AIAgentConfig( @@ -197,7 +199,7 @@ private AIConfig buildConfig( interpolate(parsed.getInstructions(), variables, context), parsed.getJudgeConfiguration(), parsed.getTools(), - TRACKER_FACTORY); + factory); case JUDGE: return new AIJudgeConfig( key, @@ -206,7 +208,7 @@ private AIConfig buildConfig( parsed.getProvider(), interpolateMessages(parsed.getMessages(), variables, context), parsed.getEvaluationMetricKey(), - TRACKER_FACTORY); + factory); case COMPLETION: default: return new AICompletionConfig( @@ -217,7 +219,7 @@ private AIConfig buildConfig( interpolateMessages(parsed.getMessages(), variables, context), parsed.getJudgeConfiguration(), parsed.getTools(), - TRACKER_FACTORY); + factory); } } @@ -231,6 +233,9 @@ private AIConfig buildConfigFromDefault( AIConfigDefault defaultValue, LDContext context, Map variables) { + // Default configs still get real trackers — the configKey was requested even if no flag was found. + // variationKey is null because no flag evaluation occurred. + Supplier factory = trackerFactory(key, null, null, null, null, context); switch (mode) { case AGENT: { AIAgentConfigDefault agent = (AIAgentConfigDefault) defaultValue; @@ -242,7 +247,7 @@ private AIConfig buildConfigFromDefault( interpolate(agent.getInstructions(), variables, context), agent.getJudgeConfiguration(), agent.getTools(), - TRACKER_FACTORY); + factory); } case JUDGE: { AIJudgeConfigDefault judge = (AIJudgeConfigDefault) defaultValue; @@ -253,7 +258,7 @@ private AIConfig buildConfigFromDefault( judge.getProvider(), interpolateMessages(judge.getMessages(), variables, context), judge.getEvaluationMetricKey(), - TRACKER_FACTORY); + factory); } case COMPLETION: default: { @@ -266,11 +271,43 @@ private AIConfig buildConfigFromDefault( interpolateMessages(completion.getMessages(), variables, context), completion.getJudgeConfiguration(), completion.getTools(), - TRACKER_FACTORY); + factory); } } } + /** + * Creates a per-evaluation tracker factory. Each call to the returned {@link Supplier} produces + * a fresh {@link LDAIConfigTrackerImpl} with a new {@code runId}. + */ + private Supplier trackerFactory( + String configKey, + String variationKey, + Integer version, + com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Model model, + com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Provider provider, + LDContext context) { + String modelName = model != null && model.getName() != null ? model.getName() : ""; + String providerName = provider != null && provider.getName() != null ? provider.getName() : ""; + int ver = version != null ? version : 1; + return () -> new LDAIConfigTrackerImpl( + client, + UUID.randomUUID().toString(), + configKey, + variationKey, + ver, + modelName, + providerName, + context, + null, // graphKey — set by agentGraph() in Plan 3 + logger); + } + + @Override + public LDAIConfigTracker createTracker(String resumptionToken, LDContext context) { + return LDAIConfigTrackerImpl.fromResumptionToken(resumptionToken, client, context, logger); + } + private List interpolateMessages( List messages, Map variables, LDContext context) { if (messages == null) { diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java index a298e33b..3f591a2c 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java @@ -1,16 +1,169 @@ package com.launchdarkly.sdk.server.ai; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.AIMetrics; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.FeedbackKind; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.JudgeResult; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.MetricSummary; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TrackData; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.function.Function; + /** * Reports events related to a single AI run of an {@link AIConfig}. *

- * A tracker is obtained from a retrieved config via {@link AIConfig#createTracker()}. Each tracker - * corresponds to one AI run and is used to record metrics such as model usage, duration, and - * feedback against the AI Config it was created from. + * A tracker is obtained from a retrieved config via {@link AIConfig#createTracker()}, or + * reconstructed from a resumption token via {@link LDAIClient#createTracker(String, com.launchdarkly.sdk.LDContext)}. + * Each tracker corresponds to one AI run and is used to record metrics such as model usage, + * duration, and feedback against the AI Config it was created from. + *

+ * Most tracking methods are at-most-once: a second call to the same method on the same tracker + * is silently dropped. {@link #trackToolCall(String)} and {@link #trackJudgeResult(JudgeResult)} + * are multi-fire — each call records a distinct event. *

- * This interface is an intentional placeholder. The metric- and feedback-reporting - * methods (and resumption-token support) are introduced in a later step of the AI SDK build-out; it - * is defined here so that the public config types expose a stable {@code createTracker()} surface. - * The only implementation in this release is an internal no-op. + * Implementations are thread-safe. */ public interface LDAIConfigTracker { + + /** + * Returns the correlation metadata for this tracker's run. + * + * @return the track data, never {@code null} + */ + TrackData getTrackData(); + + /** + * Returns the resumption token for this run. + *

+ * The resumption token encodes the run's identity and can be passed to + * {@link LDAIClient#createTracker(String, com.launchdarkly.sdk.LDContext)} to reconstruct a + * tracker on a subsequent request (for example, in a streaming scenario). + *

+ * Security note: resumption tokens embed flag-evaluation details such as the + * variation key and config version. Keep tokens server-side and do not round-trip them through + * untrusted clients where they could leak flag-targeting information. + * + * @return the resumption token, or {@code null} if not available + */ + String getResumptionToken(); + + /** + * Records the duration of the AI generation. + *

+ * At-most-once: subsequent calls on the same tracker are silently dropped. + * + * @param duration the duration; ignored if {@code null} + */ + void trackDuration(Duration duration); + + /** + * Executes the given operation and records its wall-clock duration. + *

+ * The duration is recorded even if the operation throws. Equivalent to wrapping the operation + * in a try/finally that calls {@link #trackDuration(Duration)}. + * + * @param the return type of the operation + * @param operation the operation to execute and time; must not be {@code null} + * @return the result of the operation + * @throws Exception if the operation throws + */ + T trackDurationOf(Callable operation) throws Exception; + + /** + * Records the time from request start to receipt of the first token. + *

+ * At-most-once: subsequent calls on the same tracker are silently dropped. + * + * @param duration the time to first token; ignored if {@code null} + */ + void trackTimeToFirstToken(Duration duration); + + /** + * Records that the AI generation succeeded. + *

+ * At-most-once and mutually exclusive with {@link #trackError()}: whichever is called first wins. + */ + void trackSuccess(); + + /** + * Records that the AI generation failed. + *

+ * At-most-once and mutually exclusive with {@link #trackSuccess()}: whichever is called first wins. + */ + void trackError(); + + /** + * Records user feedback for this AI generation. + *

+ * At-most-once: subsequent calls on the same tracker are silently dropped. + * + * @param kind the feedback kind; ignored if {@code null} + */ + void trackFeedback(FeedbackKind kind); + + /** + * Records token usage for this AI generation. + *

+ * At-most-once: subsequent calls on the same tracker are silently dropped. Calls where all + * counts are zero do not consume the at-most-once slot. + * + * @param tokens the token usage; ignored if {@code null} + */ + void trackTokens(TokenUsage tokens); + + /** + * Records a single tool call made during this AI generation. + *

+ * Multi-fire: every call emits an event. + * + * @param toolKey the tool key; ignored if {@code null} + */ + void trackToolCall(String toolKey); + + /** + * Records multiple tool calls made during this AI generation. + *

+ * Equivalent to calling {@link #trackToolCall(String)} for each key. + * + * @param toolKeys the tool keys; ignored if {@code null} + */ + void trackToolCalls(List toolKeys); + + /** + * Records the result of a judge evaluation. + *

+ * Multi-fire per judge metric key. The result is silently skipped if it was not sampled, if + * the evaluation did not succeed, or if the metric key or score is absent. + * + * @param result the judge result; ignored if {@code null} + */ + void trackJudgeResult(JudgeResult result); + + /** + * Executes the given operation and tracks its metrics using the extracted {@link AIMetrics}. + *

+ * Tracks duration (preferring runner-reported duration when present), success or error, tokens, + * and tool calls. If the operation throws, {@link #trackError()} is called and the exception + * is re-thrown. + * + * @param the return type of the operation + * @param metricsExtractor a function that extracts {@link AIMetrics} from the operation result; + * exceptions from the extractor propagate to the caller + * @param operation the AI operation to execute; must not be {@code null} + * @return the result of the operation + * @throws Exception if the operation or the metrics extractor throws + */ + T trackMetricsOf( + Function metricsExtractor, + Callable operation) throws Exception; + + /** + * Returns a snapshot of all metrics tracked so far on this tracker. + * + * @return the metric summary, never {@code null} + */ + MetricSummary getSummary(); } diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java new file mode 100644 index 00000000..fbae0264 --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java @@ -0,0 +1,718 @@ +package com.launchdarkly.sdk.server.ai.datamodel; + +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.ObjectBuilder; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * Container for the shared, immutable AI tracking data types. + *

+ * These shapes ({@link FeedbackKind}, {@link TokenUsage}, {@link AIMetrics}, {@link JudgeResult}, + * {@link MetricSummary}, and {@link TrackData}) are used by {@link com.launchdarkly.sdk.server.ai.LDAIConfigTracker} + * and its implementations to report AI run metrics and feedback. They are grouped under this single + * type, rather than declared as separate top-level classes, to keep the package small and to free + * up generic names. + *

+ * This class is not instantiable. + */ +public final class LDAITrackingTypes { + private LDAITrackingTypes() { + } + + /** + * The kind of user feedback reported via {@code trackFeedback}. + */ + public enum FeedbackKind { + /** + * Positive (thumbs-up) feedback. + */ + POSITIVE("positive"), + + /** + * Negative (thumbs-down) feedback. + */ + NEGATIVE("negative"); + + private final String value; + + FeedbackKind(String value) { + this.value = value; + } + + /** + * Returns the wire representation of this feedback kind. + * + * @return the wire value (for example {@code "positive"}) + */ + public String getValue() { + return value; + } + } + + /** + * Token usage counts for a single AI generation. + *

+ * Instances are immutable. + */ + public static final class TokenUsage { + private final long total; + private final long input; + private final long output; + + /** + * Constructs token usage counts. + * + * @param total the total token count + * @param input the input (prompt) token count + * @param output the output (completion) token count + */ + public TokenUsage(long total, long input, long output) { + this.total = total; + this.input = input; + this.output = output; + } + + /** + * Returns the total token count. + * + * @return the total token count + */ + public long getTotal() { + return total; + } + + /** + * Returns the input (prompt) token count. + * + * @return the input token count + */ + public long getInput() { + return input; + } + + /** + * Returns the output (completion) token count. + * + * @return the output token count + */ + public long getOutput() { + return output; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof TokenUsage)) { + return false; + } + TokenUsage other = (TokenUsage) o; + return total == other.total && input == other.input && output == other.output; + } + + @Override + public int hashCode() { + return Objects.hash(total, input, output); + } + + @Override + public String toString() { + return "TokenUsage{total=" + total + ", input=" + input + ", output=" + output + '}'; + } + } + + /** + * Metrics extracted from a single AI generation, used with {@code trackMetricsOf}. + *

+ * Build instances with {@link #builder()}. + *

+ * Instances are immutable. + */ + public static final class AIMetrics { + private final boolean success; + private final TokenUsage tokens; + private final List toolCalls; + private final Long durationMs; + + private AIMetrics(Builder b) { + this.success = b.success; + this.tokens = b.tokens; + this.toolCalls = b.toolCalls == null ? null : Collections.unmodifiableList(new ArrayList<>(b.toolCalls)); + this.durationMs = b.durationMs; + } + + /** + * Returns whether the AI generation succeeded. + * + * @return {@code true} if the generation succeeded + */ + public boolean isSuccess() { + return success; + } + + /** + * Returns the token usage for this generation. + * + * @return the token usage, or {@code null} if not reported + */ + public TokenUsage getTokens() { + return tokens; + } + + /** + * Returns the tool calls made during this generation. + * + * @return an unmodifiable list of tool call keys, or {@code null} if not reported + */ + public List getToolCalls() { + return toolCalls; + } + + /** + * Returns the duration of the AI generation in milliseconds, as reported by the runner. + *

+ * When set, {@code trackMetricsOf} uses this value instead of its own wall-clock measurement. + * + * @return the runner-reported duration in milliseconds, or {@code null} if not reported + */ + public Long getDurationMs() { + return durationMs; + } + + /** + * Creates a new builder. + * + * @return a new {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link AIMetrics}. + */ + public static final class Builder { + private boolean success; + private TokenUsage tokens; + private List toolCalls; + private Long durationMs; + + private Builder() { + } + + /** + * Sets whether the AI generation succeeded. + * + * @param success {@code true} if the generation succeeded + * @return this builder + */ + public Builder success(boolean success) { + this.success = success; + return this; + } + + /** + * Sets the token usage. + * + * @param tokens the token usage; may be {@code null} + * @return this builder + */ + public Builder tokens(TokenUsage tokens) { + this.tokens = tokens; + return this; + } + + /** + * Sets the tool calls made during this generation. + * + * @param toolCalls the tool call keys; may be {@code null} + * @return this builder + */ + public Builder toolCalls(List toolCalls) { + this.toolCalls = toolCalls; + return this; + } + + /** + * Sets the runner-reported duration in milliseconds. + * + * @param durationMs the duration; may be {@code null} + * @return this builder + */ + public Builder durationMs(Long durationMs) { + this.durationMs = durationMs; + return this; + } + + /** + * Builds the immutable {@link AIMetrics}. + * + * @return a new {@link AIMetrics} + */ + public AIMetrics build() { + return new AIMetrics(this); + } + } + } + + /** + * The result of a judge evaluation, reported via {@code trackJudgeResult}. + *

+ * Build instances with {@link #builder()}. + *

+ * Instances are immutable. + */ + public static final class JudgeResult { + private final String judgeConfigKey; + private final boolean success; + private final String errorMessage; + private final boolean sampled; + private final String metricKey; + private final Double score; + private final String reasoning; + + private JudgeResult(Builder b) { + this.judgeConfigKey = b.judgeConfigKey; + this.success = b.success; + this.errorMessage = b.errorMessage; + this.sampled = b.sampled; + this.metricKey = b.metricKey; + this.score = b.score; + this.reasoning = b.reasoning; + } + + /** + * Returns the key of the judge AI Config, if known. + * + * @return the judge config key, or {@code null} if not set + */ + public String getJudgeConfigKey() { + return judgeConfigKey; + } + + /** + * Returns whether the judge evaluation succeeded. + * + * @return {@code true} if the evaluation succeeded + */ + public boolean isSuccess() { + return success; + } + + /** + * Returns an error message from the judge evaluation, if any. + * + * @return the error message, or {@code null} if none + */ + public String getErrorMessage() { + return errorMessage; + } + + /** + * Returns whether this result was selected for sampling. + * + * @return {@code true} if the result was sampled + */ + public boolean isSampled() { + return sampled; + } + + /** + * Returns the metric key to use when emitting this result. + * + * @return the metric key, or {@code null} if not set + */ + public String getMetricKey() { + return metricKey; + } + + /** + * Returns the judge score. + *

+ * A {@code null} score is distinct from a score of {@code 0.0} — a null score means no score + * was produced, while {@code 0.0} is a valid score. + * + * @return the score, or {@code null} if not set + */ + public Double getScore() { + return score; + } + + /** + * Returns the judge's reasoning, if any. + * + * @return the reasoning, or {@code null} if none + */ + public String getReasoning() { + return reasoning; + } + + /** + * Creates a new builder. + * + * @return a new {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link JudgeResult}. + */ + public static final class Builder { + private String judgeConfigKey; + private boolean success; + private String errorMessage; + private boolean sampled; + private String metricKey; + private Double score; + private String reasoning; + + private Builder() { + } + + /** + * Sets the judge config key. + * + * @param judgeConfigKey the key; may be {@code null} + * @return this builder + */ + public Builder judgeConfigKey(String judgeConfigKey) { + this.judgeConfigKey = judgeConfigKey; + return this; + } + + /** + * Sets whether the judge evaluation succeeded. + * + * @param success {@code true} if succeeded + * @return this builder + */ + public Builder success(boolean success) { + this.success = success; + return this; + } + + /** + * Sets the error message. + * + * @param errorMessage the error message; may be {@code null} + * @return this builder + */ + public Builder errorMessage(String errorMessage) { + this.errorMessage = errorMessage; + return this; + } + + /** + * Sets whether this result was sampled. + * + * @param sampled {@code true} if sampled + * @return this builder + */ + public Builder sampled(boolean sampled) { + this.sampled = sampled; + return this; + } + + /** + * Sets the metric key. + * + * @param metricKey the metric key; may be {@code null}, but must not be blank if non-null + * @return this builder + * @throws IllegalArgumentException if {@code metricKey} is non-null and blank + */ + public Builder metricKey(String metricKey) { + if (metricKey != null && metricKey.trim().isEmpty()) { + throw new IllegalArgumentException("metricKey must not be blank"); + } + this.metricKey = metricKey; + return this; + } + + /** + * Sets the judge score. + * + * @param score the score; may be {@code null}, but must be finite if non-null + * @return this builder + * @throws IllegalArgumentException if {@code score} is non-null and non-finite (NaN or infinite) + */ + public Builder score(Double score) { + if (score != null && !Double.isFinite(score)) { + throw new IllegalArgumentException("score must be finite"); + } + this.score = score; + return this; + } + + /** + * Sets the reasoning. + * + * @param reasoning the reasoning; may be {@code null} + * @return this builder + */ + public Builder reasoning(String reasoning) { + this.reasoning = reasoning; + return this; + } + + /** + * Builds the immutable {@link JudgeResult}. + * + * @return a new {@link JudgeResult} + */ + public JudgeResult build() { + return new JudgeResult(this); + } + } + } + + /** + * A snapshot of all metrics tracked by a single {@link com.launchdarkly.sdk.server.ai.LDAIConfigTracker}. + *

+ * Returned by {@code getSummary()}. All fields are nullable — {@code null} indicates the + * corresponding metric has not been recorded yet. {@link #getSuccess()} is a tri-state: + * {@code null} = not yet tracked, {@code true} = success was recorded, {@code false} = error + * was recorded. + *

+ * Instances are immutable. + */ + public static final class MetricSummary { + private final Boolean success; + private final TokenUsage tokens; + private final Long durationMs; + private final FeedbackKind feedback; + private final Long timeToFirstTokenMs; + private final List toolCalls; + private final String resumptionToken; + + /** + * Constructs a metric summary snapshot. + * + * @param success tri-state outcome: {@code null} = not tracked, {@code true} = success, {@code false} = error + * @param tokens the token usage, or {@code null} + * @param durationMs the duration in milliseconds, or {@code null} + * @param feedback the feedback kind, or {@code null} + * @param timeToFirstTokenMs the time to first token in milliseconds, or {@code null} + * @param toolCalls the tool calls made, or {@code null} + * @param resumptionToken the resumption token, or {@code null} + */ + public MetricSummary( + Boolean success, + TokenUsage tokens, + Long durationMs, + FeedbackKind feedback, + Long timeToFirstTokenMs, + List toolCalls, + String resumptionToken) { + this.success = success; + this.tokens = tokens; + this.durationMs = durationMs; + this.feedback = feedback; + this.timeToFirstTokenMs = timeToFirstTokenMs; + this.toolCalls = toolCalls == null ? null : Collections.unmodifiableList(new ArrayList<>(toolCalls)); + this.resumptionToken = resumptionToken; + } + + /** + * Returns the outcome of the AI generation, as a tri-state. + * + * @return {@code null} if not tracked, {@code true} if success was recorded, {@code false} if error was recorded + */ + public Boolean getSuccess() { + return success; + } + + /** + * Returns the token usage. + * + * @return the token usage, or {@code null} if not tracked + */ + public TokenUsage getTokens() { + return tokens; + } + + /** + * Returns the duration in milliseconds. + * + * @return the duration, or {@code null} if not tracked + */ + public Long getDurationMs() { + return durationMs; + } + + /** + * Returns the feedback kind. + * + * @return the feedback, or {@code null} if not tracked + */ + public FeedbackKind getFeedback() { + return feedback; + } + + /** + * Returns the time to first token in milliseconds. + * + * @return the time to first token, or {@code null} if not tracked + */ + public Long getTimeToFirstTokenMs() { + return timeToFirstTokenMs; + } + + /** + * Returns the tool calls made during the generation. + * + * @return an unmodifiable list of tool call keys, or {@code null} if none were tracked + */ + public List getToolCalls() { + return toolCalls; + } + + /** + * Returns the resumption token for this tracker. + *

+ * Security note: resumption tokens embed flag-evaluation details such as the + * variation key and config version. Keep tokens server-side and do not round-trip them through + * untrusted clients where they could leak flag-targeting information. + * + * @return the resumption token, or {@code null} if not available + */ + public String getResumptionToken() { + return resumptionToken; + } + } + + /** + * Correlation metadata attached to every metric event emitted by a tracker. + *

+ * Instances are immutable. + */ + public static final class TrackData { + private final String runId; + private final String configKey; + private final String variationKey; + private final int version; + private final String modelName; + private final String providerName; + private final String graphKey; + + /** + * Constructs track data. + * + * @param runId the unique run identifier; must not be {@code null} + * @param configKey the AI Config key; must not be {@code null} + * @param variationKey the variation key, or {@code null} when a default config is used + * @param version the config version + * @param modelName the model name, or empty string when unknown + * @param providerName the provider name, or empty string when unknown + * @param graphKey the agent graph key, or {@code null} when not part of a graph + */ + public TrackData( + String runId, + String configKey, + String variationKey, + int version, + String modelName, + String providerName, + String graphKey) { + this.runId = Objects.requireNonNull(runId, "runId"); + this.configKey = Objects.requireNonNull(configKey, "configKey"); + this.variationKey = variationKey; + this.version = version; + this.modelName = modelName == null ? "" : modelName; + this.providerName = providerName == null ? "" : providerName; + this.graphKey = graphKey; + } + + /** + * Returns the unique run identifier. + * + * @return the run ID, never {@code null} + */ + public String getRunId() { + return runId; + } + + /** + * Returns the AI Config key. + * + * @return the config key, never {@code null} + */ + public String getConfigKey() { + return configKey; + } + + /** + * Returns the variation key. + * + * @return the variation key, or {@code null} when a default config is used + */ + public String getVariationKey() { + return variationKey; + } + + /** + * Returns the config version. + * + * @return the version + */ + public int getVersion() { + return version; + } + + /** + * Returns the model name. + * + * @return the model name, or empty string when unknown + */ + public String getModelName() { + return modelName; + } + + /** + * Returns the provider name. + * + * @return the provider name, or empty string when unknown + */ + public String getProviderName() { + return providerName; + } + + /** + * Returns the agent graph key. + * + * @return the graph key, or {@code null} when not part of a graph + */ + public String getGraphKey() { + return graphKey; + } + + /** + * Builds an {@link LDValue} representation of this track data using camelCase keys. + *

+ * {@code variationKey} and {@code graphKey} are omitted when {@code null}. + * + * @return an {@link LDValue} object containing all non-null fields + */ + public LDValue toLDValue() { + ObjectBuilder b = LDValue.buildObject() + .put("runId", runId) + .put("configKey", configKey) + .put("version", version) + .put("modelName", modelName) + .put("providerName", providerName); + if (variationKey != null) { + b.put("variationKey", variationKey); + } + if (graphKey != null) { + b.put("graphKey", graphKey); + } + return b.build(); + } + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java new file mode 100644 index 00000000..766d59e7 --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java @@ -0,0 +1,389 @@ +package com.launchdarkly.sdk.server.ai.internal; + +import com.launchdarkly.logging.LDLogger; +import com.launchdarkly.sdk.LDContext; +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.ObjectBuilder; +import com.launchdarkly.sdk.server.ai.LDAIConfigTracker; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.AIMetrics; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.FeedbackKind; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.JudgeResult; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.MetricSummary; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TrackData; +import com.launchdarkly.sdk.server.interfaces.LDClientInterface; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.Callable; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +/** + * The default {@link LDAIConfigTracker} implementation. + *

+ * Tracks AI run metrics and emits them as LaunchDarkly custom events via the wrapped + * {@link LDClientInterface}. At-most-once semantics for each metric type are enforced using + * {@link AtomicReference#compareAndSet} — a single atomic operation that serves as both guard + * and value store, eliminating the race window present in a two-step check-then-act pattern. + *

+ * This class is an internal implementation detail and is not part of the supported API. + */ +public final class LDAIConfigTrackerImpl implements LDAIConfigTracker { + + private static final String DURATION_TOTAL = "$ld:ai:duration:total"; + private static final String TOKENS_TTF = "$ld:ai:tokens:ttf"; + private static final String GENERATION_SUCCESS = "$ld:ai:generation:success"; + private static final String GENERATION_ERROR = "$ld:ai:generation:error"; + private static final String FEEDBACK_POSITIVE = "$ld:ai:feedback:user:positive"; + private static final String FEEDBACK_NEGATIVE = "$ld:ai:feedback:user:negative"; + private static final String TOKENS_TOTAL = "$ld:ai:tokens:total"; + private static final String TOKENS_INPUT = "$ld:ai:tokens:input"; + private static final String TOKENS_OUTPUT = "$ld:ai:tokens:output"; + private static final String TOOL_CALL = "$ld:ai:tool_call"; + + private final LDClientInterface client; + private final LDContext context; + private final LDLogger logger; + + // Identity fields + private final String runId; + private final String configKey; + private final String variationKey; // nullable — null when using a default config + private final int version; + private final String modelName; // empty string when unknown + private final String providerName; // empty string when unknown + private final String graphKey; // nullable + + // Computed once at construction + private final String resumptionToken; + + // At-most-once slots: null = not yet recorded, non-null = recorded with this value. + // AtomicReference.compareAndSet(null, value) is a single atomic operation — both guard and + // value store — eliminating the race window in an AtomicBoolean + volatile approach. + private final AtomicReference durationMs = new AtomicReference<>(); + private final AtomicReference timeToFirstTokenMs = new AtomicReference<>(); + // Shared by trackSuccess and trackError: true = success, false = error + private final AtomicReference outcome = new AtomicReference<>(); + private final AtomicReference feedbackRef = new AtomicReference<>(); + private final AtomicReference tokensRef = new AtomicReference<>(); + + // Multi-fire accumulator — not at-most-once + private final CopyOnWriteArrayList toolCalls = new CopyOnWriteArrayList<>(); + + /** + * Creates a tracker for a new AI run. + * + * @param client the LaunchDarkly client used to emit events; must not be {@code null} + * @param runId the unique run identifier (UUID v4); must not be {@code null} + * @param configKey the AI Config key; must not be {@code null} + * @param variationKey the variation key, or {@code null} when using a default config + * @param version the config version + * @param modelName the model name, or empty string when unknown + * @param providerName the provider name, or empty string when unknown + * @param context the evaluation context; must not be {@code null} + * @param graphKey the agent graph key, or {@code null} when not part of a graph + * @param logger the logger; must not be {@code null} + */ + public LDAIConfigTrackerImpl( + LDClientInterface client, + String runId, + String configKey, + String variationKey, + int version, + String modelName, + String providerName, + LDContext context, + String graphKey, + LDLogger logger) { + this.client = Objects.requireNonNull(client, "client"); + this.runId = Objects.requireNonNull(runId, "runId"); + this.configKey = Objects.requireNonNull(configKey, "configKey"); + this.variationKey = variationKey; + this.version = version; + this.modelName = modelName == null ? "" : modelName; + this.providerName = providerName == null ? "" : providerName; + this.context = Objects.requireNonNull(context, "context"); + this.graphKey = graphKey; + this.logger = Objects.requireNonNull(logger, "logger"); + + // Compute once at construction — all inputs are immutable. + this.resumptionToken = ResumptionTokens.encode(runId, configKey, variationKey, version, graphKey); + } + + /** + * Reconstructs a tracker from a resumption token, preserving the original run's identity. + * + * @param token the resumption token + * @param client the LaunchDarkly client; must not be {@code null} + * @param context the evaluation context; must not be {@code null} + * @param logger the logger; must not be {@code null} + * @return a new tracker with the decoded run identity + * @throws IllegalArgumentException if the token is malformed + */ + public static LDAIConfigTrackerImpl fromResumptionToken( + String token, LDClientInterface client, LDContext context, LDLogger logger) { + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + return new LDAIConfigTrackerImpl( + client, + d.getRunId(), + d.getConfigKey(), + d.getVariationKey(), + d.getVersion(), + "", // modelName not carried in token + "", // providerName not carried in token + context, + d.getGraphKey(), + logger); + } + + @Override + public TrackData getTrackData() { + return new TrackData(runId, configKey, variationKey, version, modelName, providerName, graphKey); + } + + @Override + public String getResumptionToken() { + return resumptionToken; + } + + @Override + public void trackDuration(Duration duration) { + if (duration == null) { + logger.debug("Skipping trackDuration: duration was null."); + return; + } + long ms = Math.max(0L, duration.toMillis()); + if (!durationMs.compareAndSet(null, ms)) { + logger.warn("Skipping trackDuration: duration already recorded on this tracker."); + return; + } + client.trackMetric(DURATION_TOTAL, context, baseData().build(), ms); + } + + @Override + public T trackDurationOf(Callable operation) throws Exception { + Objects.requireNonNull(operation, "operation"); + long start = System.nanoTime(); + try { + return operation.call(); + } finally { + long elapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start); + trackDuration(Duration.ofMillis(elapsedMs)); + } + } + + @Override + public void trackTimeToFirstToken(Duration duration) { + if (duration == null) { + logger.debug("Skipping trackTimeToFirstToken: duration was null."); + return; + } + long ms = Math.max(0L, duration.toMillis()); + if (!timeToFirstTokenMs.compareAndSet(null, ms)) { + logger.warn("Skipping trackTimeToFirstToken: time-to-first-token already recorded on this tracker."); + return; + } + client.trackMetric(TOKENS_TTF, context, baseData().build(), ms); + } + + @Override + public void trackSuccess() { + if (!outcome.compareAndSet(null, Boolean.TRUE)) { + logger.warn("Skipping trackSuccess: outcome already recorded on this tracker."); + return; + } + client.trackMetric(GENERATION_SUCCESS, context, baseData().build(), 1); + } + + @Override + public void trackError() { + if (!outcome.compareAndSet(null, Boolean.FALSE)) { + logger.warn("Skipping trackError: outcome already recorded on this tracker."); + return; + } + client.trackMetric(GENERATION_ERROR, context, baseData().build(), 1); + } + + @Override + public void trackFeedback(FeedbackKind kind) { + if (kind == null) { + logger.debug("Skipping trackFeedback: kind was null."); + return; + } + // Resolve event name BEFORE claiming the guard — an exception here must not burn the slot. + String eventName = kind == FeedbackKind.POSITIVE ? FEEDBACK_POSITIVE : FEEDBACK_NEGATIVE; + if (!feedbackRef.compareAndSet(null, kind)) { + logger.warn("Skipping trackFeedback: feedback already recorded on this tracker."); + return; + } + client.trackMetric(eventName, context, baseData().build(), 1); + } + + @Override + public void trackTokens(TokenUsage tokens) { + if (tokens == null) { + logger.debug("Skipping trackTokens: tokens was null."); + return; + } + boolean hasPositive = tokens.getTotal() > 0 || tokens.getInput() > 0 || tokens.getOutput() > 0; + if (!hasPositive) { + // Do not burn the at-most-once slot when all counts are zero. + return; + } + if (!tokensRef.compareAndSet(null, tokens)) { + logger.warn("Skipping trackTokens: token usage already recorded on this tracker."); + return; + } + if (tokens.getTotal() > 0) { + client.trackMetric(TOKENS_TOTAL, context, baseData().build(), tokens.getTotal()); + } + if (tokens.getInput() > 0) { + client.trackMetric(TOKENS_INPUT, context, baseData().build(), tokens.getInput()); + } + if (tokens.getOutput() > 0) { + client.trackMetric(TOKENS_OUTPUT, context, baseData().build(), tokens.getOutput()); + } + } + + @Override + public void trackToolCall(String toolKey) { + if (toolKey == null) { + logger.debug("Skipping trackToolCall: toolKey was null."); + return; + } + toolCalls.add(toolKey); + LDValue data = baseData().put("toolKey", toolKey).build(); + client.trackMetric(TOOL_CALL, context, data, 1); + } + + @Override + public void trackToolCalls(List toolKeys) { + if (toolKeys == null) { + return; + } + for (String key : toolKeys) { + trackToolCall(key); + } + } + + @Override + public void trackJudgeResult(JudgeResult result) { + if (result == null) { + logger.debug("Skipping trackJudgeResult: result was null."); + return; + } + if (!result.isSampled()) { + return; + } + if (!result.isSuccess()) { + return; + } + if (result.getMetricKey() == null || result.getMetricKey().trim().isEmpty()) { + return; + } + if (result.getScore() == null || !Double.isFinite(result.getScore())) { + return; + } + ObjectBuilder data = baseData(); + if (result.getJudgeConfigKey() != null) { + data.put("judgeConfigKey", result.getJudgeConfigKey()); + } + client.trackMetric(result.getMetricKey(), context, data.build(), result.getScore()); + } + + @Override + public T trackMetricsOf( + Function metricsExtractor, + Callable operation) throws Exception { + Objects.requireNonNull(metricsExtractor, "metricsExtractor"); + Objects.requireNonNull(operation, "operation"); + + long start = System.nanoTime(); + T result; + try { + result = operation.call(); + } catch (Exception e) { + // Operation failed — track measured duration + error, then re-throw. + long elapsed = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start); + trackDuration(Duration.ofMillis(elapsed)); + trackError(); + throw e; + } + // Capture operation duration immediately so a slow extractor does not inflate the metric. + long operationElapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start); + + // Extractor exceptions propagate to the caller, but the operation's duration must still be + // recorded — the AI operation itself succeeded, only the user-supplied extractor failed. + // Do NOT call trackError(); that signals the operation failed, which is not what happened. + AIMetrics metrics; + try { + metrics = Objects.requireNonNull(metricsExtractor.apply(result), "metricsExtractor returned null"); + } catch (RuntimeException e) { + trackDuration(Duration.ofMillis(operationElapsedMs)); + throw e; + } + + // Duration: prefer runner-reported value (§1.1.13.2), fall back to wall-clock. + if (metrics.getDurationMs() != null) { + trackDuration(Duration.ofMillis(metrics.getDurationMs())); + } else { + trackDuration(Duration.ofMillis(operationElapsedMs)); + } + + if (metrics.isSuccess()) { + trackSuccess(); + } else { + trackError(); + } + + if (metrics.getTokens() != null) { + trackTokens(metrics.getTokens()); + } + if (metrics.getToolCalls() != null) { + trackToolCalls(metrics.getToolCalls()); + } + + return result; + } + + @Override + public MetricSummary getSummary() { + List snapshot = toolCalls.isEmpty() + ? null + : Collections.unmodifiableList(new ArrayList<>(toolCalls)); + return new MetricSummary( + outcome.get(), + tokensRef.get(), + durationMs.get(), + feedbackRef.get(), + timeToFirstTokenMs.get(), + snapshot, + resumptionToken); + } + + /** + * Returns a pre-populated {@link LDValue.ObjectBuilder} containing the base track-data fields. + * Individual track methods add per-event fields before calling {@link LDValue.ObjectBuilder#build()}. + */ + private ObjectBuilder baseData() { + ObjectBuilder b = LDValue.buildObject() + .put("runId", runId) + .put("configKey", configKey) + .put("version", version) + .put("modelName", modelName) + .put("providerName", providerName); + if (variationKey != null) { + b.put("variationKey", variationKey); + } + if (graphKey != null) { + b.put("graphKey", graphKey); + } + return b; + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/NoOpAIConfigTracker.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/NoOpAIConfigTracker.java deleted file mode 100644 index 1cbc3c51..00000000 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/NoOpAIConfigTracker.java +++ /dev/null @@ -1,19 +0,0 @@ -package com.launchdarkly.sdk.server.ai.internal; - -import com.launchdarkly.sdk.server.ai.LDAIConfigTracker; - -/** - * The no-op {@link LDAIConfigTracker} used until metric reporting is implemented in a later step of - * the AI SDK. It is immutable and stateless, so a single shared instance is safe to reuse. - *

- * This class is an internal implementation detail and is not part of the supported API. - */ -public final class NoOpAIConfigTracker implements LDAIConfigTracker { - /** - * The shared instance. - */ - public static final NoOpAIConfigTracker INSTANCE = new NoOpAIConfigTracker(); - - private NoOpAIConfigTracker() { - } -} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java new file mode 100644 index 00000000..ed15c16a --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java @@ -0,0 +1,290 @@ +package com.launchdarkly.sdk.server.ai.internal; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; + +/** + * Encodes and decodes resumption tokens for {@link LDAIConfigTrackerImpl}. + *

+ * A resumption token is a URL-safe Base64 (RFC 4648, no padding) encoding of a canonical JSON + * object that carries the run's identity fields. Tokens can be stored by callers and passed back + * to {@link com.launchdarkly.sdk.server.ai.LDAIClient#createTracker} to reconstruct a tracker + * across requests (for example, in a streaming or multi-turn scenario). + *

+ * This class is an internal implementation detail and is not part of the supported API. + */ +final class ResumptionTokens { + private static final Base64.Encoder ENCODER = Base64.getUrlEncoder().withoutPadding(); + private static final Base64.Decoder DECODER = Base64.getUrlDecoder(); + + private ResumptionTokens() { + } + + /** + * Encodes a resumption token from the given run identity fields. + *

+ * Field order in the JSON: {@code runId}, {@code configKey}, {@code variationKey} (omitted if + * {@code null}), {@code version}, {@code graphKey} (omitted if {@code null}). + * + * @param runId the run ID + * @param configKey the AI Config key + * @param variationKey the variation key, or {@code null} to omit + * @param version the config version + * @param graphKey the graph key, or {@code null} to omit + * @return the URL-safe Base64-encoded token + */ + static String encode(String runId, String configKey, String variationKey, + int version, String graphKey) { + StringBuilder sb = new StringBuilder(); + sb.append("{\"runId\":\"").append(escapeJson(runId)).append('"'); + sb.append(",\"configKey\":\"").append(escapeJson(configKey)).append('"'); + if (variationKey != null) { + sb.append(",\"variationKey\":\"").append(escapeJson(variationKey)).append('"'); + } + sb.append(",\"version\":").append(version); + if (graphKey != null) { + sb.append(",\"graphKey\":\"").append(escapeJson(graphKey)).append('"'); + } + sb.append('}'); + return ENCODER.encodeToString(sb.toString().getBytes(StandardCharsets.UTF_8)); + } + + /** + * Decodes a resumption token previously produced by {@link #encode}. + * + * @param token the URL-safe Base64 token + * @return the decoded fields + * @throws IllegalArgumentException if the token is malformed or missing required fields + */ + static Decoded decode(String token) { + if (token == null) { + throw new IllegalArgumentException("Resumption token must not be null"); + } + + String json; + try { + byte[] bytes = DECODER.decode(token); + json = new String(bytes, StandardCharsets.UTF_8); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Resumption token is not valid Base64: " + e.getMessage(), e); + } + + return parseJson(json); + } + + /** + * Minimal JSON parser for the fixed token structure. Handles only the fields we write. + */ + private static Decoded parseJson(String json) { + json = json.trim(); + if (!json.startsWith("{") || !json.endsWith("}")) { + throw new IllegalArgumentException("Resumption token JSON must be an object"); + } + + String runId = null; + String configKey = null; + String variationKey = null; + Integer version = null; + String graphKey = null; + + // Walk through the JSON object fields + int pos = 1; // skip opening '{' + while (pos < json.length() - 1) { + pos = skipWhitespace(json, pos); + if (pos >= json.length() - 1) { + break; + } + if (json.charAt(pos) == ',') { + pos++; + pos = skipWhitespace(json, pos); + } + if (pos >= json.length() - 1) { + break; + } + + // Read key + if (json.charAt(pos) != '"') { + throw new IllegalArgumentException("Expected '\"' at position " + pos + " in resumption token"); + } + int[] keyEnd = new int[1]; + String key = readString(json, pos, keyEnd); + pos = keyEnd[0]; + + pos = skipWhitespace(json, pos); + if (pos >= json.length() || json.charAt(pos) != ':') { + throw new IllegalArgumentException("Expected ':' after key in resumption token"); + } + pos++; // skip ':' + pos = skipWhitespace(json, pos); + + // Read value + if (json.charAt(pos) == '"') { + int[] valEnd = new int[1]; + String value = readString(json, pos, valEnd); + pos = valEnd[0]; + switch (key) { + case "runId": runId = value; break; + case "configKey": configKey = value; break; + case "variationKey": variationKey = value; break; + case "graphKey": graphKey = value; break; + default: break; + } + } else { + // numeric value + int start = pos; + while (pos < json.length() && json.charAt(pos) != ',' && json.charAt(pos) != '}') { + pos++; + } + String numStr = json.substring(start, pos).trim(); + if ("version".equals(key)) { + try { + version = Integer.parseInt(numStr); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Field 'version' must be an integer in resumption token", e); + } + } + } + } + + if (runId == null || runId.isEmpty()) { + throw new IllegalArgumentException("Resumption token missing required field 'runId'"); + } + if (configKey == null || configKey.isEmpty()) { + throw new IllegalArgumentException("Resumption token missing required field 'configKey'"); + } + if (version == null) { + throw new IllegalArgumentException("Resumption token missing required field 'version'"); + } + + return new Decoded(runId, configKey, variationKey, version, graphKey); + } + + private static int skipWhitespace(String s, int pos) { + while (pos < s.length() && Character.isWhitespace(s.charAt(pos))) { + pos++; + } + return pos; + } + + /** + * Reads a JSON string starting at {@code pos} (which must point to the opening {@code "}). + * Populates {@code end[0]} with the position after the closing {@code "}. + */ + private static String readString(String s, int pos, int[] end) { + if (s.charAt(pos) != '"') { + throw new IllegalArgumentException("Expected '\"' at position " + pos); + } + pos++; // skip opening quote + StringBuilder sb = new StringBuilder(); + while (pos < s.length()) { + char c = s.charAt(pos); + if (c == '"') { + end[0] = pos + 1; + return sb.toString(); + } else if (c == '\\') { + pos++; + if (pos >= s.length()) { + throw new IllegalArgumentException("Unterminated escape in resumption token"); + } + char escaped = s.charAt(pos); + switch (escaped) { + case '"': sb.append('"'); break; + case '\\': sb.append('\\'); break; + case '/': sb.append('/'); break; + case 'b': sb.append('\b'); break; + case 'f': sb.append('\f'); break; + case 'n': sb.append('\n'); break; + case 'r': sb.append('\r'); break; + case 't': sb.append('\t'); break; + case 'u': + if (pos + 4 >= s.length()) { + throw new IllegalArgumentException("Incomplete Unicode escape in resumption token"); + } + String hex = s.substring(pos + 1, pos + 5); + try { + sb.append((char) Integer.parseInt(hex, 16)); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid Unicode escape \\u" + hex + " in resumption token"); + } + pos += 4; + break; + default: + throw new IllegalArgumentException("Unknown escape character '\\" + escaped + "' in resumption token"); + } + } else { + sb.append(c); + } + pos++; + } + throw new IllegalArgumentException("Unterminated string in resumption token"); + } + + /** + * Escapes a string for inclusion in a JSON string literal. Handles the characters required by + * RFC 8259 §7. + */ + static String escapeJson(String s) { + if (s == null) { + return ""; + } + StringBuilder sb = new StringBuilder(s.length()); + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + switch (c) { + case '"': sb.append("\\\""); break; + case '\\': sb.append("\\\\"); break; + case '\b': sb.append("\\b"); break; + case '\f': sb.append("\\f"); break; + case '\n': sb.append("\\n"); break; + case '\r': sb.append("\\r"); break; + case '\t': sb.append("\\t"); break; + default: + if (c < 0x20) { + sb.append(String.format("\\u%04x", (int) c)); + } else { + sb.append(c); + } + } + } + return sb.toString(); + } + + /** + * The decoded fields from a resumption token. + */ + static final class Decoded { + private final String runId; + private final String configKey; + private final String variationKey; + private final int version; + private final String graphKey; + + Decoded(String runId, String configKey, String variationKey, int version, String graphKey) { + this.runId = runId; + this.configKey = configKey; + this.variationKey = variationKey; + this.version = version; + this.graphKey = graphKey; + } + + String getRunId() { + return runId; + } + + String getConfigKey() { + return configKey; + } + + String getVariationKey() { + return variationKey; + } + + int getVersion() { + return version; + } + + String getGraphKey() { + return graphKey; + } + } +} diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java new file mode 100644 index 00000000..ffcadaa9 --- /dev/null +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java @@ -0,0 +1,738 @@ +package com.launchdarkly.sdk.server.ai.internal; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyDouble; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import com.launchdarkly.logging.LDLogLevel; +import com.launchdarkly.logging.LDLogger; +import com.launchdarkly.logging.LogCapture; +import com.launchdarkly.logging.Logs; +import com.launchdarkly.sdk.LDContext; +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.server.ai.LDAIConfigTracker; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.AIMetrics; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.FeedbackKind; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.JudgeResult; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.MetricSummary; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TrackData; +import com.launchdarkly.sdk.server.interfaces.LDClientInterface; + +import java.time.Duration; +import java.util.Arrays; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +@SuppressWarnings("javadoc") +public class LDAIConfigTrackerImplTest { + private LDClientInterface client; + private LogCapture logCapture; + private LDLogger logger; + private LDAIConfigTrackerImpl tracker; + + private static final LDContext CONTEXT = LDContext.create("user-key"); + private static final String RUN_ID = "test-run-id"; + private static final String CONFIG_KEY = "my-config"; + private static final String VARIATION_KEY = "variation-abc"; + private static final int VERSION = 3; + private static final String MODEL_NAME = "gpt-4"; + private static final String PROVIDER_NAME = "openai"; + + @Before + public void setUp() { + client = mock(LDClientInterface.class); + logCapture = Logs.capture(); + logger = LDLogger.withAdapter(logCapture, "test"); + tracker = makeTracker(VARIATION_KEY); + } + + private LDAIConfigTrackerImpl makeTracker(String variationKey) { + return new LDAIConfigTrackerImpl( + client, RUN_ID, CONFIG_KEY, variationKey, VERSION, + MODEL_NAME, PROVIDER_NAME, CONTEXT, null, logger); + } + + private List warnings() { + return logCapture.getMessages().stream() + .filter(m -> m.getLevel() == LDLogLevel.WARN) + .map(LogCapture.Message::getText) + .collect(Collectors.toList()); + } + + private List debugs() { + return logCapture.getMessages().stream() + .filter(m -> m.getLevel() == LDLogLevel.DEBUG) + .map(LogCapture.Message::getText) + .collect(Collectors.toList()); + } + + private LDValue baseExpectedData() { + return LDValue.buildObject() + .put("runId", RUN_ID) + .put("configKey", CONFIG_KEY) + .put("variationKey", VARIATION_KEY) + .put("version", VERSION) + .put("modelName", MODEL_NAME) + .put("providerName", PROVIDER_NAME) + .build(); + } + + // ---- getTrackData / getResumptionToken ------------------------------------ + + @Test + public void getTrackDataReturnsCorrectFields() { + TrackData data = tracker.getTrackData(); + assertThat(data.getRunId(), is(RUN_ID)); + assertThat(data.getConfigKey(), is(CONFIG_KEY)); + assertThat(data.getVariationKey(), is(VARIATION_KEY)); + assertThat(data.getVersion(), is(VERSION)); + assertThat(data.getModelName(), is(MODEL_NAME)); + assertThat(data.getProviderName(), is(PROVIDER_NAME)); + assertThat(data.getGraphKey(), is(nullValue())); + } + + @Test + public void getTrackDataOmitsVariationKeyWhenNull() { + LDAIConfigTrackerImpl t = makeTracker(null); + assertThat(t.getTrackData().getVariationKey(), is(nullValue())); + LDValue ldv = t.getTrackData().toLDValue(); + assertThat(ldv.get("variationKey").isNull(), is(true)); // absent key returns LDValue.ofNull() + } + + @Test + public void getResumptionTokenIsNotNull() { + assertThat(tracker.getResumptionToken(), is(notNullValue())); + } + + @Test + public void resumptionTokenRoundTrips() throws Exception { + String token = tracker.getResumptionToken(); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getRunId(), is(RUN_ID)); + assertThat(d.getConfigKey(), is(CONFIG_KEY)); + assertThat(d.getVariationKey(), is(VARIATION_KEY)); + assertThat(d.getVersion(), is(VERSION)); + assertThat(d.getGraphKey(), is(nullValue())); + } + + @Test + public void fromResumptionTokenRestoresCorrectFields() { + String token = tracker.getResumptionToken(); + LDAIConfigTrackerImpl restored = + LDAIConfigTrackerImpl.fromResumptionToken(token, client, CONTEXT, logger); + TrackData data = restored.getTrackData(); + assertThat(data.getRunId(), is(RUN_ID)); + assertThat(data.getConfigKey(), is(CONFIG_KEY)); + assertThat(data.getVariationKey(), is(VARIATION_KEY)); + assertThat(data.getVersion(), is(VERSION)); + assertThat(data.getModelName(), is("")); // not in token + assertThat(data.getProviderName(), is("")); // not in token + } + + // ---- trackDuration -------------------------------------------------------- + + @Test + public void trackDurationEmitsCorrectEvent() { + tracker.trackDuration(Duration.ofMillis(500)); + verify(client).trackMetric( + eq("$ld:ai:duration:total"), eq(CONTEXT), eq(baseExpectedData()), eq(500.0)); + } + + @Test + public void trackDurationClampsNegativeToZero() { + tracker.trackDuration(Duration.ofMillis(-100)); + verify(client).trackMetric( + eq("$ld:ai:duration:total"), eq(CONTEXT), eq(baseExpectedData()), eq(0.0)); + } + + @Test + public void trackDurationAtMostOnce() { + tracker.trackDuration(Duration.ofMillis(100)); + tracker.trackDuration(Duration.ofMillis(200)); + verify(client, times(1)).trackMetric( + eq("$ld:ai:duration:total"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + assertThat(warnings().get(0), containsString("duration")); + } + + @Test + public void trackDurationNullIsIgnoredWithDebugLog() { + tracker.trackDuration(null); + verify(client, never()).trackMetric(eq("$ld:ai:duration:total"), any(), any(), anyDouble()); + assertThat(debugs().size(), greaterThanOrEqualTo(1)); + assertThat(warnings(), is(empty())); + } + + // ---- trackDurationOf ------------------------------------------------------ + + @Test + public void trackDurationOfReturnResultAndTracksDuration() throws Exception { + String result = tracker.trackDurationOf(() -> "hello"); + assertThat(result, is("hello")); + verify(client, times(1)).trackMetric(eq("$ld:ai:duration:total"), any(), any(), anyDouble()); + } + + @Test + public void trackDurationOfTracksDurationEvenOnException() { + try { + tracker.trackDurationOf(() -> { + throw new RuntimeException("boom"); + }); + } catch (Exception ignored) { + } + verify(client, times(1)).trackMetric(eq("$ld:ai:duration:total"), any(), any(), anyDouble()); + } + + // ---- trackTimeToFirstToken ------------------------------------------------ + + @Test + public void trackTimeToFirstTokenEmitsCorrectEvent() { + tracker.trackTimeToFirstToken(Duration.ofMillis(250)); + verify(client).trackMetric( + eq("$ld:ai:tokens:ttf"), eq(CONTEXT), eq(baseExpectedData()), eq(250.0)); + } + + @Test + public void trackTimeToFirstTokenAtMostOnce() { + tracker.trackTimeToFirstToken(Duration.ofMillis(100)); + tracker.trackTimeToFirstToken(Duration.ofMillis(200)); + verify(client, times(1)).trackMetric(eq("$ld:ai:tokens:ttf"), any(), any(), anyDouble()); + } + + @Test + public void trackTimeToFirstTokenNullIsIgnoredWithDebugLog() { + tracker.trackTimeToFirstToken(null); + verify(client, never()).trackMetric(eq("$ld:ai:tokens:ttf"), any(), any(), anyDouble()); + assertThat(debugs().size(), greaterThanOrEqualTo(1)); + assertThat(warnings(), is(empty())); + } + + // ---- trackSuccess / trackError -------------------------------------------- + + @Test + public void trackSuccessEmitsCorrectEvent() { + tracker.trackSuccess(); + verify(client).trackMetric( + eq("$ld:ai:generation:success"), eq(CONTEXT), eq(baseExpectedData()), eq(1.0)); + } + + @Test + public void trackErrorEmitsCorrectEvent() { + tracker.trackError(); + verify(client).trackMetric( + eq("$ld:ai:generation:error"), eq(CONTEXT), eq(baseExpectedData()), eq(1.0)); + } + + @Test + public void trackSuccessAtMostOnce() { + tracker.trackSuccess(); + tracker.trackSuccess(); + verify(client, times(1)).trackMetric(eq("$ld:ai:generation:success"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + @Test + public void trackErrorAtMostOnce() { + tracker.trackError(); + tracker.trackError(); + verify(client, times(1)).trackMetric(eq("$ld:ai:generation:error"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + @Test + public void trackSuccessAndErrorShareGuard_successFirst() { + tracker.trackSuccess(); + tracker.trackError(); + verify(client, times(1)).trackMetric(eq("$ld:ai:generation:success"), any(), any(), anyDouble()); + verify(client, never()).trackMetric(eq("$ld:ai:generation:error"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + @Test + public void trackSuccessAndErrorShareGuard_errorFirst() { + tracker.trackError(); + tracker.trackSuccess(); + verify(client, times(1)).trackMetric(eq("$ld:ai:generation:error"), any(), any(), anyDouble()); + verify(client, never()).trackMetric(eq("$ld:ai:generation:success"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + // ---- trackFeedback -------------------------------------------------------- + + @Test + public void trackFeedbackPositiveEmitsCorrectEvent() { + tracker.trackFeedback(FeedbackKind.POSITIVE); + verify(client).trackMetric( + eq("$ld:ai:feedback:user:positive"), eq(CONTEXT), eq(baseExpectedData()), eq(1.0)); + } + + @Test + public void trackFeedbackNegativeEmitsCorrectEvent() { + tracker.trackFeedback(FeedbackKind.NEGATIVE); + verify(client).trackMetric( + eq("$ld:ai:feedback:user:negative"), eq(CONTEXT), eq(baseExpectedData()), eq(1.0)); + } + + @Test + public void trackFeedbackAtMostOnce() { + tracker.trackFeedback(FeedbackKind.POSITIVE); + tracker.trackFeedback(FeedbackKind.NEGATIVE); + verify(client, times(1)).trackMetric( + eq("$ld:ai:feedback:user:positive"), any(), any(), anyDouble()); + verify(client, never()).trackMetric( + eq("$ld:ai:feedback:user:negative"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + @Test + public void trackFeedbackNullIsIgnoredWithDebugLog_slotNotBurned() { + tracker.trackFeedback(null); + assertThat(debugs().size(), greaterThanOrEqualTo(1)); + assertThat(warnings(), is(empty())); + // Slot should not be burned — a subsequent valid call should still work + tracker.trackFeedback(FeedbackKind.POSITIVE); + verify(client, times(1)).trackMetric(eq("$ld:ai:feedback:user:positive"), any(), any(), anyDouble()); + } + + // ---- trackTokens ---------------------------------------------------------- + + @Test + public void trackTokensEmitsEventsForPositiveCounts() { + tracker.trackTokens(new TokenUsage(100, 60, 40)); + verify(client).trackMetric(eq("$ld:ai:tokens:total"), eq(CONTEXT), eq(baseExpectedData()), eq(100.0)); + verify(client).trackMetric(eq("$ld:ai:tokens:input"), eq(CONTEXT), eq(baseExpectedData()), eq(60.0)); + verify(client).trackMetric(eq("$ld:ai:tokens:output"), eq(CONTEXT), eq(baseExpectedData()), eq(40.0)); + } + + @Test + public void trackTokensSkipsZeroCounts() { + tracker.trackTokens(new TokenUsage(0, 0, 40)); + verify(client, never()).trackMetric(eq("$ld:ai:tokens:total"), any(), any(), anyDouble()); + verify(client, never()).trackMetric(eq("$ld:ai:tokens:input"), any(), any(), anyDouble()); + verify(client).trackMetric(eq("$ld:ai:tokens:output"), any(), any(), eq(40.0)); + } + + @Test + public void trackTokensAllZeroDoesNotBurnSlot() { + tracker.trackTokens(new TokenUsage(0, 0, 0)); + // Slot not burned — next valid call should succeed + tracker.trackTokens(new TokenUsage(10, 5, 5)); + verify(client).trackMetric(eq("$ld:ai:tokens:total"), any(), any(), eq(10.0)); + } + + @Test + public void trackTokensAtMostOnce() { + tracker.trackTokens(new TokenUsage(10, 5, 5)); + tracker.trackTokens(new TokenUsage(20, 10, 10)); + verify(client, times(1)).trackMetric(eq("$ld:ai:tokens:total"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + @Test + public void trackTokensNullIsIgnoredWithDebugLog() { + tracker.trackTokens(null); + verify(client, never()).trackMetric(eq("$ld:ai:tokens:total"), any(), any(), anyDouble()); + assertThat(debugs().size(), greaterThanOrEqualTo(1)); + assertThat(warnings(), is(empty())); + } + + // ---- trackToolCall -------------------------------------------------------- + + @Test + public void trackToolCallEmitsOnEveryCall() { + LDValue expectedDataWithTool = LDValue.buildObject() + .put("runId", RUN_ID).put("configKey", CONFIG_KEY) + .put("variationKey", VARIATION_KEY).put("version", VERSION) + .put("modelName", MODEL_NAME).put("providerName", PROVIDER_NAME) + .put("toolKey", "search") + .build(); + + tracker.trackToolCall("search"); + tracker.trackToolCall("search"); + tracker.trackToolCall("fetch"); + + verify(client, times(2)).trackMetric( + eq("$ld:ai:tool_call"), eq(CONTEXT), eq(expectedDataWithTool), eq(1.0)); + LDValue fetchData = LDValue.buildObject() + .put("runId", RUN_ID).put("configKey", CONFIG_KEY) + .put("variationKey", VARIATION_KEY).put("version", VERSION) + .put("modelName", MODEL_NAME).put("providerName", PROVIDER_NAME) + .put("toolKey", "fetch") + .build(); + verify(client, times(1)).trackMetric( + eq("$ld:ai:tool_call"), eq(CONTEXT), eq(fetchData), eq(1.0)); + } + + @Test + public void trackToolCallsDelegate() { + tracker.trackToolCalls(Arrays.asList("a", "b")); + verify(client, times(2)).trackMetric(eq("$ld:ai:tool_call"), any(), any(), anyDouble()); + } + + @Test + public void trackToolCallNullIsIgnoredWithDebugLog() { + tracker.trackToolCall(null); + verify(client, never()).trackMetric(eq("$ld:ai:tool_call"), any(), any(), anyDouble()); + assertThat(debugs().size(), greaterThanOrEqualTo(1)); + assertThat(warnings(), is(empty())); + } + + // ---- trackJudgeResult ----------------------------------------------------- + + @Test + public void trackJudgeResultEmitsWhenSampledAndSucceeded() { + JudgeResult result = JudgeResult.builder() + .sampled(true).success(true) + .metricKey("judge-score").score(0.85) + .judgeConfigKey("my-judge") + .build(); + + LDValue expectedData = LDValue.buildObject() + .put("runId", RUN_ID).put("configKey", CONFIG_KEY) + .put("variationKey", VARIATION_KEY).put("version", VERSION) + .put("modelName", MODEL_NAME).put("providerName", PROVIDER_NAME) + .put("judgeConfigKey", "my-judge") + .build(); + + tracker.trackJudgeResult(result); + verify(client).trackMetric(eq("judge-score"), eq(CONTEXT), eq(expectedData), eq(0.85)); + } + + @Test + public void trackJudgeResultSkipsWhenNotSampled() { + JudgeResult result = JudgeResult.builder() + .sampled(false).success(true).metricKey("k").score(1.0).build(); + tracker.trackJudgeResult(result); + verify(client, never()).trackMetric(eq("k"), any(), any(), anyDouble()); + } + + @Test + public void trackJudgeResultSkipsWhenNotSuccess() { + JudgeResult result = JudgeResult.builder() + .sampled(true).success(false).metricKey("k").score(1.0).build(); + tracker.trackJudgeResult(result); + verify(client, never()).trackMetric(eq("k"), any(), any(), anyDouble()); + } + + @Test + public void trackJudgeResultSkipsWhenMetricKeyNull() { + JudgeResult result = JudgeResult.builder() + .sampled(true).success(true).metricKey(null).score(1.0).build(); + tracker.trackJudgeResult(result); + verify(client, never()).trackMetric(any(), any(), any(), anyDouble()); + } + + @Test + public void trackJudgeResultSkipsWhenScoreNull() { + JudgeResult result = JudgeResult.builder() + .sampled(true).success(true).metricKey("k").score(null).build(); + tracker.trackJudgeResult(result); + verify(client, never()).trackMetric(eq("k"), any(), any(), anyDouble()); + } + + @Test + public void trackJudgeResultFiresWhenScoreIsZero() { + JudgeResult result = JudgeResult.builder() + .sampled(true).success(true).metricKey("k").score(0.0).build(); + tracker.trackJudgeResult(result); + verify(client).trackMetric(eq("k"), any(), any(), eq(0.0)); + } + + @Test + public void trackJudgeResultOmitsJudgeConfigKeyWhenNull() { + JudgeResult result = JudgeResult.builder() + .sampled(true).success(true).metricKey("k").score(1.0).judgeConfigKey(null).build(); + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + tracker.trackJudgeResult(result); + verify(client).trackMetric(eq("k"), any(), dataCaptor.capture(), anyDouble()); + assertThat(dataCaptor.getValue().get("judgeConfigKey").isNull(), is(true)); + } + + @Test + public void trackJudgeResultIsNotAtMostOnce() { + JudgeResult r1 = JudgeResult.builder().sampled(true).success(true).metricKey("k1").score(1.0).build(); + JudgeResult r2 = JudgeResult.builder().sampled(true).success(true).metricKey("k2").score(2.0).build(); + tracker.trackJudgeResult(r1); + tracker.trackJudgeResult(r2); + verify(client).trackMetric(eq("k1"), any(), any(), eq(1.0)); + verify(client).trackMetric(eq("k2"), any(), any(), eq(2.0)); + } + + @Test + public void trackJudgeResultNullIsIgnoredWithDebugLog() { + tracker.trackJudgeResult(null); + assertThat(debugs().size(), greaterThanOrEqualTo(1)); + assertThat(warnings(), is(empty())); + } + + // ---- trackMetricsOf ------------------------------------------------------- + + @Test + public void trackMetricsOfTracksSuccessAndDurationAndTokens() throws Exception { + AIMetrics metrics = AIMetrics.builder() + .success(true) + .tokens(new TokenUsage(10, 6, 4)) + .build(); + + String result = tracker.trackMetricsOf(r -> metrics, () -> "ok"); + assertThat(result, is("ok")); + + verify(client).trackMetric(eq("$ld:ai:generation:success"), any(), any(), eq(1.0)); + verify(client).trackMetric(eq("$ld:ai:duration:total"), any(), any(), anyDouble()); + verify(client).trackMetric(eq("$ld:ai:tokens:total"), any(), any(), eq(10.0)); + verify(client).trackMetric(eq("$ld:ai:tokens:input"), any(), any(), eq(6.0)); + verify(client).trackMetric(eq("$ld:ai:tokens:output"), any(), any(), eq(4.0)); + } + + @Test + public void trackMetricsOfUsesRunnerReportedDurationWhenPresent() throws Exception { + AIMetrics metrics = AIMetrics.builder().success(true).durationMs(999L).build(); + tracker.trackMetricsOf(r -> metrics, () -> "ok"); + verify(client).trackMetric(eq("$ld:ai:duration:total"), any(), any(), eq(999.0)); + } + + @Test + public void trackMetricsOfWallClockDurationExcludesSlowExtractor() throws Exception { + // Operation returns immediately; extractor sleeps. Recorded duration must reflect only the + // operation, not the extractor work. + long extractorSleepMs = 200L; + AIMetrics metrics = AIMetrics.builder().success(true).build(); + tracker.trackMetricsOf( + r -> { try { Thread.sleep(extractorSleepMs); } catch (InterruptedException ie) { Thread.currentThread().interrupt(); } return metrics; }, + () -> "ok"); + ArgumentCaptor durationCaptor = ArgumentCaptor.forClass(Double.class); + verify(client).trackMetric(eq("$ld:ai:duration:total"), any(), any(), durationCaptor.capture()); + assertThat( + "wall-clock duration must not include extractor time", + durationCaptor.getValue() < (double) extractorSleepMs / 2, + is(true)); + } + + @Test + public void trackMetricsOfTracksErrorAndRethrowsOnOperationException() { + try { + tracker.trackMetricsOf( + r -> AIMetrics.builder().success(true).build(), + () -> { throw new RuntimeException("ai failed"); }); + } catch (Exception e) { + assertThat(e.getMessage(), is("ai failed")); + } + verify(client).trackMetric(eq("$ld:ai:generation:error"), any(), any(), eq(1.0)); + verify(client).trackMetric(eq("$ld:ai:duration:total"), any(), any(), anyDouble()); + verify(client, never()).trackMetric(eq("$ld:ai:generation:success"), any(), any(), anyDouble()); + } + + @Test + public void trackMetricsOfExtractorExceptionPropagatesAndDoesNotCallTrackError() { + try { + tracker.trackMetricsOf( + r -> { throw new RuntimeException("extractor failed"); }, + () -> "ok"); + } catch (Exception e) { + assertThat(e.getMessage(), is("extractor failed")); + } + verify(client, never()).trackMetric(eq("$ld:ai:generation:error"), any(), any(), anyDouble()); + verify(client, never()).trackMetric(eq("$ld:ai:generation:success"), any(), any(), anyDouble()); + } + + @Test + public void trackMetricsOfRecordsDurationWhenExtractorThrows() { + try { + tracker.trackMetricsOf( + r -> { throw new RuntimeException("extractor failed"); }, + () -> "ok"); + } catch (Exception e) { + // expected; we care that duration was recorded before the throw + } + verify(client).trackMetric(eq("$ld:ai:duration:total"), any(), any(), anyDouble()); + } + + @Test + public void trackMetricsOfTracksToolCalls() throws Exception { + AIMetrics metrics = AIMetrics.builder() + .success(true) + .toolCalls(Arrays.asList("search", "fetch")) + .build(); + tracker.trackMetricsOf(r -> metrics, () -> "ok"); + verify(client, times(2)).trackMetric(eq("$ld:ai:tool_call"), any(), any(), eq(1.0)); + } + + // ---- getSummary ----------------------------------------------------------- + + @Test + public void getSummaryReturnsNullsBeforeAnyTracking() { + MetricSummary summary = tracker.getSummary(); + assertThat(summary.getSuccess(), is(nullValue())); + assertThat(summary.getDurationMs(), is(nullValue())); + assertThat(summary.getTokens(), is(nullValue())); + assertThat(summary.getFeedback(), is(nullValue())); + assertThat(summary.getTimeToFirstTokenMs(), is(nullValue())); + assertThat(summary.getToolCalls(), is(nullValue())); + assertThat(summary.getResumptionToken(), is(notNullValue())); + } + + @Test + public void getSummaryReflectsAllTrackedValues() { + tracker.trackDuration(Duration.ofMillis(300)); + tracker.trackTimeToFirstToken(Duration.ofMillis(50)); + tracker.trackSuccess(); + tracker.trackFeedback(FeedbackKind.POSITIVE); + tracker.trackTokens(new TokenUsage(30, 20, 10)); + tracker.trackToolCall("search"); + tracker.trackToolCall("fetch"); + + MetricSummary summary = tracker.getSummary(); + assertThat(summary.getSuccess(), is(Boolean.TRUE)); + assertThat(summary.getDurationMs(), is(300L)); + assertThat(summary.getTimeToFirstTokenMs(), is(50L)); + assertThat(summary.getFeedback(), is(FeedbackKind.POSITIVE)); + assertThat(summary.getTokens().getTotal(), is(30L)); + assertThat(summary.getToolCalls(), containsInAnyOrder("search", "fetch")); + assertThat(summary.getResumptionToken(), is(tracker.getResumptionToken())); + } + + @Test + public void getSummarySuccessIsFalseWhenErrorTracked() { + tracker.trackError(); + assertThat(tracker.getSummary().getSuccess(), is(Boolean.FALSE)); + } + + @Test + public void getSummaryToolCallsIsImmutableSnapshot() { + tracker.trackToolCall("a"); + List snapshot1 = tracker.getSummary().getToolCalls(); + tracker.trackToolCall("b"); + List snapshot2 = tracker.getSummary().getToolCalls(); + assertThat(snapshot1.size(), is(1)); + assertThat(snapshot2.size(), is(2)); + } + + // ---- variationKey omission ------------------------------------------------ + + @Test + public void variationKeyOmittedFromPayloadWhenNull() { + LDAIConfigTrackerImpl t = makeTracker(null); + t.trackSuccess(); + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric(eq("$ld:ai:generation:success"), any(), dataCaptor.capture(), anyDouble()); + assertThat(dataCaptor.getValue().get("variationKey").isNull(), is(true)); + } + + @Test + public void variationKeyIncludedInPayloadWhenPresent() { + tracker.trackSuccess(); + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric(eq("$ld:ai:generation:success"), any(), dataCaptor.capture(), anyDouble()); + assertThat(dataCaptor.getValue().get("variationKey").stringValue(), is(VARIATION_KEY)); + } + + // ---- graphKey inclusion --------------------------------------------------- + + @Test + public void graphKeyIncludedInPayloadWhenSet() { + LDAIConfigTrackerImpl t = new LDAIConfigTrackerImpl( + client, RUN_ID, CONFIG_KEY, VARIATION_KEY, VERSION, + MODEL_NAME, PROVIDER_NAME, CONTEXT, "my-graph", logger); + t.trackSuccess(); + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric(eq("$ld:ai:generation:success"), any(), dataCaptor.capture(), anyDouble()); + assertThat(dataCaptor.getValue().get("graphKey").stringValue(), is("my-graph")); + } + + @Test + public void graphKeyOmittedFromPayloadWhenNull() { + tracker.trackSuccess(); + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric(eq("$ld:ai:generation:success"), any(), dataCaptor.capture(), anyDouble()); + assertThat(dataCaptor.getValue().get("graphKey").isNull(), is(true)); + } + + // ---- concurrency: at-most-once under contention --------------------------- + + @Test + public void trackDurationAtMostOnceUnderConcurrency() throws InterruptedException { + int threads = 20; + CountDownLatch ready = new CountDownLatch(threads); + CountDownLatch go = new CountDownLatch(1); + AtomicInteger callCount = new AtomicInteger(0); + ExecutorService exec = Executors.newFixedThreadPool(threads); + + for (int i = 0; i < threads; i++) { + exec.submit(() -> { + ready.countDown(); + try { go.await(); } catch (InterruptedException ignored) {} + tracker.trackDuration(Duration.ofMillis(100)); + }); + } + + ready.await(); + go.countDown(); + exec.shutdown(); + exec.awaitTermination(5, TimeUnit.SECONDS); + + ArgumentCaptor valueCaptor = ArgumentCaptor.forClass(Double.class); + verify(client, times(1)).trackMetric( + eq("$ld:ai:duration:total"), any(), any(), valueCaptor.capture()); + } + + @Test + public void trackSuccessAtMostOnceUnderConcurrency() throws InterruptedException { + int threads = 20; + CountDownLatch ready = new CountDownLatch(threads); + CountDownLatch go = new CountDownLatch(1); + ExecutorService exec = Executors.newFixedThreadPool(threads); + + for (int i = 0; i < threads; i++) { + exec.submit(() -> { + ready.countDown(); + try { go.await(); } catch (InterruptedException ignored) {} + tracker.trackSuccess(); + }); + } + + ready.await(); + go.countDown(); + exec.shutdown(); + exec.awaitTermination(5, TimeUnit.SECONDS); + + verify(client, times(1)).trackMetric(eq("$ld:ai:generation:success"), any(), any(), anyDouble()); + } + + // ---- constructor null checks ---------------------------------------------- + + @Test(expected = NullPointerException.class) + public void constructorRejectsNullClient() { + new LDAIConfigTrackerImpl(null, RUN_ID, CONFIG_KEY, VARIATION_KEY, VERSION, + MODEL_NAME, PROVIDER_NAME, CONTEXT, null, logger); + } + + @Test(expected = NullPointerException.class) + public void constructorRejectsNullContext() { + new LDAIConfigTrackerImpl(client, RUN_ID, CONFIG_KEY, VARIATION_KEY, VERSION, + MODEL_NAME, PROVIDER_NAME, null, null, logger); + } +} diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java new file mode 100644 index 00000000..bed64ca9 --- /dev/null +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java @@ -0,0 +1,164 @@ +package com.launchdarkly.sdk.server.ai.internal; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; + +import org.junit.Test; + +@SuppressWarnings("javadoc") +public class ResumptionTokensTest { + + // ---- encode + decode round-trips ------------------------------------------ + + @Test + public void roundTripWithAllFields() { + String token = ResumptionTokens.encode("run-1", "config-key", "var-abc", 2, "graph-x"); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getRunId(), is("run-1")); + assertThat(d.getConfigKey(), is("config-key")); + assertThat(d.getVariationKey(), is("var-abc")); + assertThat(d.getVersion(), is(2)); + assertThat(d.getGraphKey(), is("graph-x")); + } + + @Test + public void roundTripWithNullVariationKey() { + String token = ResumptionTokens.encode("run-1", "config-key", null, 1, null); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getRunId(), is("run-1")); + assertThat(d.getConfigKey(), is("config-key")); + assertThat(d.getVariationKey(), is(nullValue())); + assertThat(d.getVersion(), is(1)); + assertThat(d.getGraphKey(), is(nullValue())); + } + + @Test + public void roundTripWithNullGraphKey() { + String token = ResumptionTokens.encode("run-2", "cfg", "v1", 3, null); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getGraphKey(), is(nullValue())); + assertThat(d.getVariationKey(), is("v1")); + } + + @Test + public void variationKeyOmittedFromTokenWhenNull() { + // Tokens with null variationKey should NOT contain the "variationKey" JSON field. + String token = ResumptionTokens.encode("r", "c", null, 1, null); + // Decode and check no variationKey + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getVariationKey(), is(nullValue())); + } + + @Test + public void graphKeyOmittedFromTokenWhenNull() { + String token = ResumptionTokens.encode("r", "c", "v", 1, null); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getGraphKey(), is(nullValue())); + } + + // ---- special character escaping ------------------------------------------- + + @Test + public void roundTripWithSpecialCharactersInKeys() { + String runId = "run\"with\\special\nchars"; + String configKey = "config\twith\rtabs"; + String token = ResumptionTokens.encode(runId, configKey, null, 1, null); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getRunId(), is(runId)); + assertThat(d.getConfigKey(), is(configKey)); + } + + @Test + public void roundTripWithUnicodeInKeys() { + String runId = "run-\u00e9\u4e2d\u6587"; + String token = ResumptionTokens.encode(runId, "cfg", null, 1, null); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getRunId(), is(runId)); + } + + // ---- version round-trip --------------------------------------------------- + + @Test + public void versionIsPreservedOnRoundTrip() { + String token = ResumptionTokens.encode("r", "c", null, 1, null); + assertThat(ResumptionTokens.decode(token).getVersion(), is(1)); + } + + // ---- large keys ----------------------------------------------------------- + + @Test + public void roundTripsLongKeys() { + String key = new String(new char[5000]).replace('\0', 'a'); + String token = ResumptionTokens.encode("run", key, null, 1, null); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getConfigKey(), is(key)); + } + + // ---- decode error handling ------------------------------------------------ + + @Test(expected = IllegalArgumentException.class) + public void decodeRejectsNull() { + ResumptionTokens.decode(null); + } + + @Test(expected = IllegalArgumentException.class) + public void decodeRejectsInvalidBase64() { + ResumptionTokens.decode("not-valid-base64!!!!"); + } + + @Test(expected = IllegalArgumentException.class) + public void decodeRejectsMissingRunId() { + String json = "{\"configKey\":\"c\",\"version\":1}"; + String token = java.util.Base64.getUrlEncoder().withoutPadding() + .encodeToString(json.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + ResumptionTokens.decode(token); + } + + @Test(expected = IllegalArgumentException.class) + public void decodeRejectsMissingConfigKey() { + String json = "{\"runId\":\"r\",\"version\":1}"; + String token = java.util.Base64.getUrlEncoder().withoutPadding() + .encodeToString(json.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + ResumptionTokens.decode(token); + } + + @Test(expected = IllegalArgumentException.class) + public void decodeRejectsMissingVersion() { + String json = "{\"runId\":\"r\",\"configKey\":\"c\"}"; + String token = java.util.Base64.getUrlEncoder().withoutPadding() + .encodeToString(json.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + ResumptionTokens.decode(token); + } + + @Test(expected = IllegalArgumentException.class) + public void decodeRejectsNonIntegerVersion() { + String json = "{\"runId\":\"r\",\"configKey\":\"c\",\"version\":\"one\"}"; + String token = java.util.Base64.getUrlEncoder().withoutPadding() + .encodeToString(json.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + ResumptionTokens.decode(token); + } + + @Test(expected = IllegalArgumentException.class) + public void decodeRejectsNonObjectJson() { + String json = "[\"not\",\"an\",\"object\"]"; + String token = java.util.Base64.getUrlEncoder().withoutPadding() + .encodeToString(json.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + ResumptionTokens.decode(token); + } + + // ---- escapeJson helper ---------------------------------------------------- + + @Test + public void escapeJsonHandlesControlCharacters() { + assertThat(ResumptionTokens.escapeJson("\n\r\t"), is("\\n\\r\\t")); + assertThat(ResumptionTokens.escapeJson("\"hello\""), is("\\\"hello\\\"")); + assertThat(ResumptionTokens.escapeJson("back\\slash"), is("back\\\\slash")); + } + + @Test + public void escapeJsonReturnsEmptyStringForNull() { + assertThat(ResumptionTokens.escapeJson(null), is("")); + } +}