| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256 |
- import { logEvent } from '../services/analytics/index.js'
- import { logForDebugging } from '../utils/debug.js'
- import { logForDiagnosticsNoPII } from '../utils/diagLogs.js'
- import { errorMessage } from '../utils/errors.js'
- import { jsonParse } from '../utils/slowOperations.js'
/**
 * Format a millisecond duration as a human-readable string (e.g. "5m 30s").
 *
 * The duration is rounded to whole seconds once, and only then split into
 * minutes and seconds. Rounding the seconds part on its own can produce
 * "1m 60s" (or "60s") when the remainder is within 500 ms of the next minute.
 *
 * @param ms Duration in milliseconds.
 * @returns `"<s>s"` below one minute, otherwise `"<m>m <s>s"`, or `"<m>m"` when
 *   the seconds part is zero.
 */
function formatDuration(ms: number): string {
  const totalSeconds = Math.round(ms / 1000)
  if (totalSeconds < 60) return `${totalSeconds}s`
  const minutes = Math.floor(totalSeconds / 60)
  const seconds = totalSeconds % 60
  return seconds > 0 ? `${minutes}m ${seconds}s` : `${minutes}m`
}
/**
 * Read the payload (middle) segment of a JWT. The signature is NOT verified,
 * so the result must not be trusted for authorization decisions.
 *
 * A leading `sk-ant-si-` session-ingress prefix is removed before the token is
 * split on `.`.
 *
 * @param token Raw token, with or without the `sk-ant-si-` prefix.
 * @returns The payload parsed as JSON, or `null` when the token does not have
 *   exactly three segments, the payload segment is empty, or decoding/parsing
 *   throws.
 */
export function decodeJwtPayload(token: string): unknown | null {
  const ingressPrefix = 'sk-ant-si-'
  let jwt = token
  if (token.startsWith(ingressPrefix)) {
    jwt = token.slice(ingressPrefix.length)
  }

  const segments = jwt.split('.')
  const payloadSegment = segments[1]
  if (segments.length !== 3 || !payloadSegment) {
    return null
  }

  try {
    const payloadJson = Buffer.from(payloadSegment, 'base64url').toString('utf8')
    return jsonParse(payloadJson)
  } catch {
    return null
  }
}
/**
 * Extract the `exp` (expiry) claim from a JWT. The signature is not verified.
 *
 * @param token Raw token; decoded via `decodeJwtPayload`.
 * @returns The `exp` value (Unix seconds) when the payload is an object with a
 *   numeric `exp` property, otherwise `null`.
 */
export function decodeJwtExpiry(token: string): number | null {
  const claims = decodeJwtPayload(token)
  if (claims === null || typeof claims !== 'object' || !('exp' in claims)) {
    return null
  }
  const { exp } = claims
  return typeof exp === 'number' ? exp : null
}
/** Default lead time before expiry at which a new token is requested (5 minutes). */
const TOKEN_REFRESH_BUFFER_MS = 5 * 60_000
/** Delay before the follow-up refresh armed after each successful refresh (30 minutes). */
const FALLBACK_REFRESH_INTERVAL_MS = 30 * 60_000
/** Consecutive failed attempts after which no further retry is scheduled. */
const MAX_REFRESH_FAILURES = 3
/** Delay before retrying when no OAuth token could be obtained (1 minute). */
const REFRESH_RETRY_DELAY_MS = 60 * 1000
/**
 * Creates a token refresh scheduler that proactively refreshes session tokens
 * before they expire. Used by both the standalone bridge and the REPL bridge.
 *
 * When a token is about to expire, the scheduler calls `onRefresh` with the
 * session ID and the bridge's OAuth access token. The caller is responsible
 * for delivering the token to the appropriate transport (child process stdin
 * for standalone bridge, WebSocket reconnect for REPL bridge).
 *
 * @param getAccessToken Supplies the OAuth token, synchronously or via a
 *   promise. A thrown error or a falsy result counts as a failed attempt.
 * @param onRefresh Invoked with `(sessionId, oauthToken)` on each successful
 *   refresh. A synchronous throw is logged and does not stop the follow-up
 *   refresh from being scheduled.
 * @param label Prefix used in debug log lines (`[<label>:token] ...`).
 * @param refreshBufferMs How long before expiry to fire the refresh.
 * @returns `schedule` (derive the delay from a JWT `exp` claim),
 *   `scheduleFromExpiresIn` (derive it from an explicit TTL in seconds),
 *   `cancel` (one session) and `cancelAll`.
 */
export function createTokenRefreshScheduler({
  getAccessToken,
  onRefresh,
  label,
  refreshBufferMs = TOKEN_REFRESH_BUFFER_MS,
}: {
  getAccessToken: () => string | undefined | Promise<string | undefined>
  onRefresh: (sessionId: string, oauthToken: string) => void
  label: string
  /** How long before expiry to fire refresh. Defaults to 5 min. */
  refreshBufferMs?: number
}): {
  schedule: (sessionId: string, token: string) => void
  scheduleFromExpiresIn: (sessionId: string, expiresInSeconds: number) => void
  cancel: (sessionId: string) => void
  cancelAll: () => void
} {
  // Largest delay setTimeout honours (2^31 - 1 ms, about 24.8 days). A larger
  // value is coerced to 1 ms by the runtime, which would fire the refresh
  // immediately instead of far in the future.
  const MAX_TIMER_DELAY_MS = 2_147_483_647
  // Floor applied to TTL-based scheduling so a short (or unusable) expires_in
  // cannot turn the refresh -> reschedule cycle into a tight loop.
  const MIN_EXPIRES_IN_DELAY_MS = 30_000

  const timers = new Map<string, ReturnType<typeof setTimeout>>()
  const failureCounts = new Map<string, number>()
  // Generation counter per session — incremented by schedule(),
  // scheduleFromExpiresIn(), cancel() and cancelAll() so that in-flight async
  // doRefresh() calls can detect when they've been superseded and should skip
  // setting follow-up timers.
  const generations = new Map<string, number>()

  /** Bump and return the generation for a session. */
  function nextGeneration(sessionId: string): number {
    const gen = (generations.get(sessionId) ?? 0) + 1
    generations.set(sessionId, gen)
    return gen
  }

  /** Cap a delay at the largest value setTimeout handles correctly. */
  function clampTimerDelay(delayMs: number): number {
    return Math.min(delayMs, MAX_TIMER_DELAY_MS)
  }

  /** Arm the refresh timer for a session and record its handle. */
  function armTimer(sessionId: string, gen: number, delayMs: number): void {
    const timer = setTimeout(doRefresh, delayMs, sessionId, gen)
    timers.set(sessionId, timer)
  }

  /**
   * Render an `exp` claim for log output. `toISOString()` throws a RangeError
   * on an out-of-range date, so fall back to the raw number in that case.
   */
  function describeExpiry(expirySeconds: number): string {
    const date = new Date(expirySeconds * 1000)
    return Number.isNaN(date.getTime())
      ? `${expirySeconds} (unix seconds, out of Date range)`
      : date.toISOString()
  }

  /**
   * Schedule a refresh `refreshBufferMs` before the `exp` claim of `token`.
   * Refreshes immediately when that moment is already in the past.
   */
  function schedule(sessionId: string, token: string): void {
    const expiry = decodeJwtExpiry(token)
    if (!expiry) {
      // Token is not a decodable JWT (e.g. an OAuth token passed from the
      // REPL bridge WebSocket open handler). Preserve any existing timer
      // (such as the follow-up refresh set by doRefresh) so the refresh
      // chain is not broken.
      logForDebugging(
        `[${label}:token] Could not decode JWT expiry for sessionId=${sessionId}, token prefix=${token.slice(0, 15)}…, keeping existing timer`,
      )
      return
    }

    // Clear any existing refresh timer — we have a concrete expiry to replace it.
    const existing = timers.get(sessionId)
    if (existing) {
      clearTimeout(existing)
    }

    // Bump generation to invalidate any in-flight async doRefresh.
    const gen = nextGeneration(sessionId)

    const expiryDate = describeExpiry(expiry)
    const delayMs = clampTimerDelay(
      expiry * 1000 - Date.now() - refreshBufferMs,
    )
    if (delayMs <= 0) {
      logForDebugging(
        `[${label}:token] Token for sessionId=${sessionId} expires=${expiryDate} (past or within buffer), refreshing immediately`,
      )
      void doRefresh(sessionId, gen)
      return
    }

    logForDebugging(
      `[${label}:token] Scheduled token refresh for sessionId=${sessionId} in ${formatDuration(delayMs)} (expires=${expiryDate}, buffer=${refreshBufferMs / 1000}s)`,
    )
    armTimer(sessionId, gen, delayMs)
  }

  /**
   * Schedule refresh using an explicit TTL (seconds until expiry) rather
   * than decoding a JWT's exp claim. Used by callers whose JWT is opaque
   * (e.g. POST /v1/code/sessions/{id}/bridge returns expires_in directly).
   */
  function scheduleFromExpiresIn(
    sessionId: string,
    expiresInSeconds: number,
  ): void {
    const existing = timers.get(sessionId)
    if (existing) clearTimeout(existing)
    const gen = nextGeneration(sessionId)

    // Clamp to 30s floor — if refreshBufferMs exceeds the server's expires_in
    // (e.g. very large buffer for frequent-refresh testing, or server shortens
    // expires_in unexpectedly), unclamped delayMs ≤ 0 would tight-loop.
    // Math.max() propagates NaN, so a non-numeric expires_in is mapped to the
    // floor explicitly; otherwise setTimeout would treat it as a 1 ms delay.
    const requestedMs = expiresInSeconds * 1000 - refreshBufferMs
    const delayMs = Number.isNaN(requestedMs)
      ? MIN_EXPIRES_IN_DELAY_MS
      : clampTimerDelay(Math.max(requestedMs, MIN_EXPIRES_IN_DELAY_MS))

    logForDebugging(
      `[${label}:token] Scheduled token refresh for sessionId=${sessionId} in ${formatDuration(delayMs)} (expires_in=${expiresInSeconds}s, buffer=${refreshBufferMs / 1000}s)`,
    )
    armTimer(sessionId, gen, delayMs)
  }

  /**
   * Fetch the OAuth token and hand it to `onRefresh`, then arm the follow-up
   * timer. `gen` is the generation this attempt was scheduled under; the
   * attempt is abandoned if the session's generation has moved on since.
   */
  async function doRefresh(sessionId: string, gen: number): Promise<void> {
    let oauthToken: string | undefined
    try {
      oauthToken = await getAccessToken()
    } catch (err) {
      logForDebugging(
        `[${label}:token] getAccessToken threw for sessionId=${sessionId}: ${errorMessage(err)}`,
        { level: 'error' },
      )
    }

    // If the session was cancelled or rescheduled while we were awaiting,
    // the generation will have changed — bail out to avoid orphaned timers.
    if (generations.get(sessionId) !== gen) {
      logForDebugging(
        `[${label}:token] doRefresh for sessionId=${sessionId} stale (gen ${gen} vs ${generations.get(sessionId)}), skipping`,
      )
      return
    }

    if (!oauthToken) {
      const failures = (failureCounts.get(sessionId) ?? 0) + 1
      failureCounts.set(sessionId, failures)
      logForDebugging(
        `[${label}:token] No OAuth token available for refresh, sessionId=${sessionId} (failure ${failures}/${MAX_REFRESH_FAILURES})`,
        { level: 'error' },
      )
      logForDiagnosticsNoPII('error', 'bridge_token_refresh_no_oauth')
      // Schedule a retry so the refresh chain can recover if the token
      // becomes available again (e.g. transient cache clear during refresh).
      // Cap retries to avoid spamming on genuine failures.
      if (failures < MAX_REFRESH_FAILURES) {
        armTimer(sessionId, gen, REFRESH_RETRY_DELAY_MS)
      }
      return
    }

    // Reset failure counter on successful token retrieval
    failureCounts.delete(sessionId)

    logForDebugging(
      `[${label}:token] Refreshing token for sessionId=${sessionId}: new token prefix=${oauthToken.slice(0, 15)}…`,
    )
    logEvent('tengu_bridge_token_refreshed', {})

    // doRefresh runs as a floating promise (timer callback / `void` call), so
    // a throw from the callback would surface as an unhandled rejection and
    // skip the follow-up timer below, silently ending the refresh chain.
    try {
      onRefresh(sessionId, oauthToken)
    } catch (err) {
      logForDebugging(
        `[${label}:token] onRefresh threw for sessionId=${sessionId}: ${errorMessage(err)}`,
        { level: 'error' },
      )
    }

    // onRefresh may have synchronously rescheduled or cancelled this session.
    // In that case the map entry belongs to the newer generation; arming the
    // follow-up here would replace that handle with an already-stale timer.
    if (generations.get(sessionId) !== gen) {
      logForDebugging(
        `[${label}:token] doRefresh for sessionId=${sessionId} stale (gen ${gen} vs ${generations.get(sessionId)}), skipping`,
      )
      return
    }

    // Schedule a follow-up refresh so long-running sessions stay authenticated.
    // Without this, the initial one-shot timer leaves the session vulnerable
    // to token expiry if it runs past the first refresh window.
    armTimer(sessionId, gen, FALLBACK_REFRESH_INTERVAL_MS)
    logForDebugging(
      `[${label}:token] Scheduled follow-up refresh for sessionId=${sessionId} in ${formatDuration(FALLBACK_REFRESH_INTERVAL_MS)}`,
    )
  }

  /** Stop refreshing one session and forget its failure count. */
  function cancel(sessionId: string): void {
    // Bump generation to invalidate any in-flight async doRefresh.
    nextGeneration(sessionId)
    const timer = timers.get(sessionId)
    if (timer) {
      clearTimeout(timer)
      timers.delete(sessionId)
    }
    failureCounts.delete(sessionId)
  }

  /** Stop refreshing every session known to this scheduler. */
  function cancelAll(): void {
    // Bump all generations so in-flight doRefresh calls are invalidated.
    for (const sessionId of generations.keys()) {
      nextGeneration(sessionId)
    }
    for (const timer of timers.values()) {
      clearTimeout(timer)
    }
    timers.clear()
    failureCounts.clear()
  }

  return { schedule, scheduleFromExpiresIn, cancel, cancelAll }
}
|