jwtUtils.ts 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. import { logEvent } from '../services/analytics/index.js'
  2. import { logForDebugging } from '../utils/debug.js'
  3. import { logForDiagnosticsNoPII } from '../utils/diagLogs.js'
  4. import { errorMessage } from '../utils/errors.js'
  5. import { jsonParse } from '../utils/slowOperations.js'
  /**
   * Format a millisecond duration as a human-readable string (e.g. "5m 30s").
   *
   * The duration is rounded to whole seconds *before* it is split into minutes
   * and seconds. Rounding the seconds remainder on its own can round up to 60
   * and produce output such as "1m 60s" (or "60s" for 59 999 ms); rounding
   * first lets the carry roll over into the minutes component.
   *
   * @param ms Duration in milliseconds.
   * @returns `"<s>s"` below one minute, otherwise `"<m>m"` or `"<m>m <s>s"`.
   */
  function formatDuration(ms: number): string {
    const totalSeconds = Math.round(ms / 1000)
    if (totalSeconds < 60) return `${totalSeconds}s`
    const minutes = Math.floor(totalSeconds / 60)
    const seconds = totalSeconds % 60
    return seconds > 0 ? `${minutes}m ${seconds}s` : `${minutes}m`
  }
  /**
   * Extract and parse the payload (middle) segment of a JWT. The signature is
   * NOT verified, so the result must not be trusted for authorization.
   *
   * A leading `sk-ant-si-` session-ingress prefix is removed before the token
   * is split on `.`.
   *
   * @param token Raw token string, with or without the `sk-ant-si-` prefix.
   * @returns The parsed JSON payload as `unknown`, or `null` when the token
   *          does not have exactly three segments, the payload segment is
   *          empty, or decoding/parsing the payload throws.
   */
  export function decodeJwtPayload(token: string): unknown | null {
    const sessionIngressPrefix = 'sk-ant-si-'
    let jwt = token
    if (token.startsWith(sessionIngressPrefix)) {
      jwt = token.slice(sessionIngressPrefix.length)
    }

    const segments = jwt.split('.')
    const payloadSegment = segments[1]
    if (segments.length !== 3 || !payloadSegment) {
      return null
    }

    try {
      const payloadJson = Buffer.from(payloadSegment, 'base64url').toString(
        'utf8',
      )
      return jsonParse(payloadJson)
    } catch {
      // Any decode/parse failure is reported as "no payload".
      return null
    }
  }
  /**
   * Read the `exp` (expiry) claim of a JWT without verifying the signature.
   *
   * @param token Raw token string; decoded via `decodeJwtPayload`.
   * @returns The `exp` value in Unix seconds, or `null` when the payload
   *          cannot be decoded, is not an object, or has no numeric `exp`.
   */
  export function decodeJwtExpiry(token: string): number | null {
    const claims = decodeJwtPayload(token)
    if (claims === null || typeof claims !== 'object') {
      return null
    }
    if (!('exp' in claims)) {
      return null
    }
    const { exp } = claims
    return typeof exp === 'number' ? exp : null
  }
  /** How far ahead of expiry a replacement token is requested (5 minutes). */
  const TOKEN_REFRESH_BUFFER_MS = 5 * 60_000
  /** Follow-up refresh interval used when the new token's expiry is unknown (30 minutes). */
  const FALLBACK_REFRESH_INTERVAL_MS = 30 * 60_000
  /** Consecutive refresh failures after which the refresh chain is abandoned. */
  const MAX_REFRESH_FAILURES = 3
  /** Delay before retrying when getAccessToken yields no token (1 minute). */
  const REFRESH_RETRY_DELAY_MS = 60 * 1000
  /**
   * Creates a token refresh scheduler that proactively refreshes session tokens
   * before they expire. Used by both the standalone bridge and the REPL bridge.
   *
   * When a token is about to expire, the scheduler calls `onRefresh` with the
   * session ID and the bridge's OAuth access token. The caller is responsible
   * for delivering the token to the appropriate transport (child process stdin
   * for standalone bridge, WebSocket reconnect for REPL bridge).
   *
   * Refreshes run from timers, so nothing can catch what they throw. For that
   * reason a throwing `getAccessToken` or `onRefresh` is logged and absorbed
   * rather than surfacing as an unhandled promise rejection.
   *
   * @param getAccessToken Supplies the current OAuth token (sync or async).
   *        Returning `undefined` or throwing counts as a failed attempt.
   * @param onRefresh Invoked with `(sessionId, oauthToken)` on each successful
   *        refresh.
   * @param label Prefix used in debug log lines (`[label:token] …`).
   * @param refreshBufferMs How long before expiry to fire the refresh.
   * @returns `schedule` (from a JWT's `exp`), `scheduleFromExpiresIn` (from an
   *          explicit TTL), `cancel` (one session) and `cancelAll`.
   */
  export function createTokenRefreshScheduler({
    getAccessToken,
    onRefresh,
    label,
    refreshBufferMs = TOKEN_REFRESH_BUFFER_MS,
  }: {
    getAccessToken: () => string | undefined | Promise<string | undefined>
    onRefresh: (sessionId: string, oauthToken: string) => void
    label: string
    /** How long before expiry to fire refresh. Defaults to 5 min. */
    refreshBufferMs?: number
  }): {
    schedule: (sessionId: string, token: string) => void
    scheduleFromExpiresIn: (sessionId: string, expiresInSeconds: number) => void
    cancel: (sessionId: string) => void
    cancelAll: () => void
  } {
    const timers = new Map<string, ReturnType<typeof setTimeout>>()
    const failureCounts = new Map<string, number>()
    // Current generation per session. A doRefresh() call captures the
    // generation it was started with; if the stored value differs (or is gone)
    // by the time it resumes, the call has been superseded by schedule() or
    // cancel() and must not touch timers.
    const generations = new Map<string, number>()
    // Generations are drawn from one scheduler-wide monotonic counter, so a
    // value is never reused. That lets cancel()/cancelAll() delete entries
    // (keeping the map bounded) without an old in-flight doRefresh() ever
    // matching a later generation of the same session.
    let generationCounter = 0

    /** Assign a fresh, never-before-used generation to the session. */
    function nextGeneration(sessionId: string): number {
      generationCounter += 1
      generations.set(sessionId, generationCounter)
      return generationCounter
    }

    /** Clear and forget the pending timer for a session, if there is one. */
    function clearTimer(sessionId: string): void {
      const existing = timers.get(sessionId)
      if (existing !== undefined) {
        clearTimeout(existing)
        timers.delete(sessionId)
      }
    }

    /**
     * Schedule a refresh based on the `exp` claim of `token`. If the claim
     * cannot be decoded, the call is a no-op and any existing timer is kept.
     */
    function schedule(sessionId: string, token: string): void {
      const expiry = decodeJwtExpiry(token)
      if (!expiry) {
        // Token is not a decodable JWT (e.g. an OAuth token passed from the
        // REPL bridge WebSocket open handler). Preserve any existing timer
        // (such as the follow-up refresh set by doRefresh) so the refresh
        // chain is not broken.
        logForDebugging(
          `[${label}:token] Could not decode JWT expiry for sessionId=${sessionId}, token prefix=${token.slice(0, 15)}…, keeping existing timer`,
        )
        return
      }

      // We have a concrete expiry: drop the existing timer and invalidate any
      // in-flight async doRefresh.
      clearTimer(sessionId)
      const gen = nextGeneration(sessionId)

      const expiryDate = new Date(expiry * 1000).toISOString()
      const delayMs = expiry * 1000 - Date.now() - refreshBufferMs
      if (delayMs <= 0) {
        logForDebugging(
          `[${label}:token] Token for sessionId=${sessionId} expires=${expiryDate} (past or within buffer), refreshing immediately`,
        )
        void doRefresh(sessionId, gen)
        return
      }

      logForDebugging(
        `[${label}:token] Scheduled token refresh for sessionId=${sessionId} in ${formatDuration(delayMs)} (expires=${expiryDate}, buffer=${refreshBufferMs / 1000}s)`,
      )
      timers.set(sessionId, setTimeout(doRefresh, delayMs, sessionId, gen))
    }

    /**
     * Schedule refresh using an explicit TTL (seconds until expiry) rather
     * than decoding a JWT's exp claim. Used by callers whose JWT is opaque
     * (e.g. POST /v1/code/sessions/{id}/bridge returns expires_in directly).
     */
    function scheduleFromExpiresIn(
      sessionId: string,
      expiresInSeconds: number,
    ): void {
      clearTimer(sessionId)
      const gen = nextGeneration(sessionId)
      // Clamp to 30s floor — if refreshBufferMs exceeds the server's expires_in
      // (e.g. very large buffer for frequent-refresh testing, or server shortens
      // expires_in unexpectedly), unclamped delayMs ≤ 0 would tight-loop.
      const delayMs = Math.max(expiresInSeconds * 1000 - refreshBufferMs, 30_000)
      logForDebugging(
        `[${label}:token] Scheduled token refresh for sessionId=${sessionId} in ${formatDuration(delayMs)} (expires_in=${expiresInSeconds}s, buffer=${refreshBufferMs / 1000}s)`,
      )
      timers.set(sessionId, setTimeout(doRefresh, delayMs, sessionId, gen))
    }

    /**
     * Perform one refresh attempt for `sessionId` on behalf of generation
     * `gen`, then arm either a retry (no token) or a follow-up (success).
     */
    async function doRefresh(sessionId: string, gen: number): Promise<void> {
      let oauthToken: string | undefined
      try {
        oauthToken = await getAccessToken()
      } catch (err) {
        // Treated like "no token" below: oauthToken stays undefined.
        logForDebugging(
          `[${label}:token] getAccessToken threw for sessionId=${sessionId}: ${errorMessage(err)}`,
          { level: 'error' },
        )
      }

      // If the session was cancelled or rescheduled while we were awaiting,
      // the generation will have changed — bail out to avoid orphaned timers.
      if (generations.get(sessionId) !== gen) {
        logForDebugging(
          `[${label}:token] doRefresh for sessionId=${sessionId} stale (gen ${gen} vs ${generations.get(sessionId)}), skipping`,
        )
        return
      }

      if (!oauthToken) {
        const failures = (failureCounts.get(sessionId) ?? 0) + 1
        failureCounts.set(sessionId, failures)
        logForDebugging(
          `[${label}:token] No OAuth token available for refresh, sessionId=${sessionId} (failure ${failures}/${MAX_REFRESH_FAILURES})`,
          { level: 'error' },
        )
        logForDiagnosticsNoPII('error', 'bridge_token_refresh_no_oauth')
        // Schedule a retry so the refresh chain can recover if the token
        // becomes available again (e.g. transient cache clear during refresh).
        // Cap retries to avoid spamming on genuine failures.
        if (failures < MAX_REFRESH_FAILURES) {
          timers.set(
            sessionId,
            setTimeout(doRefresh, REFRESH_RETRY_DELAY_MS, sessionId, gen),
          )
        } else {
          // Giving up: the handle that fired us is spent, don't keep it around.
          timers.delete(sessionId)
        }
        return
      }

      // Reset failure counter on successful token retrieval
      failureCounts.delete(sessionId)
      logForDebugging(
        `[${label}:token] Refreshing token for sessionId=${sessionId}: new token prefix=${oauthToken.slice(0, 15)}…`,
      )
      logEvent('tengu_bridge_token_refreshed', {})
      try {
        onRefresh(sessionId, oauthToken)
      } catch (err) {
        // doRefresh is only ever started from a timer or `void`, so a throw
        // here would become an unhandled rejection and silently end the
        // refresh chain. Log it and keep the chain alive instead.
        logForDebugging(
          `[${label}:token] onRefresh threw for sessionId=${sessionId}: ${errorMessage(err)}`,
          { level: 'error' },
        )
      }

      // onRefresh may itself have called schedule()/cancel() for this session.
      // If so, the timer state now belongs to that newer generation — do not
      // overwrite it with a follow-up tied to the old one.
      if (generations.get(sessionId) !== gen) {
        return
      }

      // Schedule a follow-up refresh so long-running sessions stay authenticated.
      // Without this, the initial one-shot timer leaves the session vulnerable
      // to token expiry if it runs past the first refresh window.
      timers.set(
        sessionId,
        setTimeout(doRefresh, FALLBACK_REFRESH_INTERVAL_MS, sessionId, gen),
      )
      logForDebugging(
        `[${label}:token] Scheduled follow-up refresh for sessionId=${sessionId} in ${formatDuration(FALLBACK_REFRESH_INTERVAL_MS)}`,
      )
    }

    /** Stop refreshing a single session and drop all state held for it. */
    function cancel(sessionId: string): void {
      // Removing the generation invalidates any in-flight async doRefresh:
      // its captured generation can no longer match.
      generations.delete(sessionId)
      clearTimer(sessionId)
      failureCounts.delete(sessionId)
    }

    /** Stop refreshing every session and drop all scheduler state. */
    function cancelAll(): void {
      // Clearing generations invalidates every in-flight doRefresh call.
      generations.clear()
      for (const timer of timers.values()) {
        clearTimeout(timer)
      }
      timers.clear()
      failureCounts.clear()
    }

    return { schedule, scheduleFromExpiresIn, cancel, cancelAll }
  }