/**
 * Decision Transformer
 *
 * Implements sequence modeling approach for RL:
 * - Trajectory as sequence: (R, s, a, R, s, a, ...) where R is the return-to-go
 * - Return-conditioned generation
 * - Causal transformer attention
 * - Offline RL from trajectories
 *
 * Performance Target: <10ms per forward pass
 */

/**
 * Default Decision Transformer configuration
 */
export const DEFAULT_DT_CONFIG = {
  algorithm: 'decision-transformer',
  learningRate: 0.0001,
  gamma: 0.99,
  entropyCoef: 0,
  valueLossCoef: 0,
  maxGradNorm: 1.0,
  epochs: 1,
  miniBatchSize: 64,
  contextLength: 20,
  numHeads: 4,
  numLayers: 2,
  hiddenDim: 64,
  embeddingDim: 32,
  dropout: 0.1,
};

/**
 * Decision Transformer Implementation
 *
 * All weights are flat, row-major `Float32Array`s. Only the action head is
 * trained (see `updateWeights`); embeddings and transformer layers keep their
 * random initialization.
 */
export class DecisionTransformer {
  config;

  // Embeddings
  stateEmbed;
  actionEmbed;
  returnEmbed;
  posEmbed;

  // Transformer layers (simplified)
  attentionWeights;
  ffnWeights;

  // Output head
  actionHead;

  // Training buffer
  trajectoryBuffer = [];

  // Dimensions
  stateDim = 768;
  numActions = 4;

  // Statistics
  updateCount = 0;
  avgLoss = 0;

  /**
   * @param {object} [config] - Overrides merged on top of DEFAULT_DT_CONFIG.
   */
  constructor(config = {}) {
    this.config = { ...DEFAULT_DT_CONFIG, ...config };
    const { embeddingDim, hiddenDim, contextLength, numLayers } = this.config;

    // Initialize embeddings
    this.stateEmbed = this.initEmbedding(this.stateDim, embeddingDim);
    this.actionEmbed = this.initEmbedding(this.numActions, embeddingDim);
    this.returnEmbed = this.initEmbedding(1, embeddingDim);
    // Three tokens (return, state, action) per timestep in the context window.
    this.posEmbed = this.initEmbedding(contextLength * 3, embeddingDim);

    // Initialize transformer layers
    this.attentionWeights = [];
    this.ffnWeights = [];
    for (let l = 0; l < numLayers; l++) {
      // Attention: Q, K, V, O projections
      this.attentionWeights.push([
        this.initWeight(embeddingDim, hiddenDim), // Q
        this.initWeight(embeddingDim, hiddenDim), // K
        this.initWeight(embeddingDim, hiddenDim), // V
        this.initWeight(hiddenDim, embeddingDim), // O
      ]);
      // FFN: up and down projections
      this.ffnWeights.push([
        this.initWeight(embeddingDim, hiddenDim * 4),
        this.initWeight(hiddenDim * 4, embeddingDim),
      ]);
    }

    // Action prediction head
    this.actionHead = this.initWeight(embeddingDim, this.numActions);
  }

  /**
   * Add trajectory for training.
   *
   * Only trajectories flagged `isComplete` with at least one step are kept.
   * The buffer is capped at the 1000 most recent trajectories.
   *
   * @param {{isComplete: boolean, steps: Array}} trajectory
   */
  addTrajectory(trajectory) {
    if (trajectory.isComplete && trajectory.steps.length > 0) {
      this.trajectoryBuffer.push(trajectory);
      // Keep buffer bounded
      if (this.trajectoryBuffer.length > 1000) {
        this.trajectoryBuffer = this.trajectoryBuffer.slice(-1000);
      }
    }
  }

  /**
   * Train on buffered trajectories
   * Target: <10ms per batch
   *
   * Samples up to `miniBatchSize` trajectories (with replacement). For every
   * position t >= 1 it predicts the action of entry t from the preceding
   * `contextLength` entries and applies one action-head update.
   *
   * @returns {{loss: number, accuracy: number}} Mean cross-entropy and
   *   top-1 accuracy over all predictions made in this call (0 when none).
   */
  train() {
    const startTime = performance.now();

    if (this.trajectoryBuffer.length === 0) {
      return { loss: 0, accuracy: 0 };
    }

    // Sample mini-batch of trajectories
    const batchSize = Math.min(this.config.miniBatchSize, this.trajectoryBuffer.length);
    const batch = [];
    for (let i = 0; i < batchSize; i++) {
      const idx = Math.floor(Math.random() * this.trajectoryBuffer.length);
      batch.push(this.trajectoryBuffer[idx]);
    }

    let totalLoss = 0;
    let correct = 0;
    let total = 0;

    for (const trajectory of batch) {
      // Create sequence from trajectory
      const sequence = this.createSequence(trajectory);
      if (sequence.length < 2) continue;

      for (let t = 1; t < sequence.length; t++) {
        // Use context up to (but excluding) position t
        const context = sequence.slice(Math.max(0, t - this.config.contextLength), t);
        const target = sequence[t];

        // Encode once and reuse the features for both prediction and update.
        const features = this.encodeLastState(context);
        const predicted = this.actionProbabilities(features);
        const predictedAction = this.argmax(predicted);

        // Cross-entropy loss
        totalLoss += -Math.log(predicted[target.action] + 1e-8);
        if (predictedAction === target.action) {
          correct++;
        }
        total++;

        // Gradient update (action head only)
        this.updateWeights(context, target.action, predicted, features);
      }
    }

    this.updateCount++;
    this.avgLoss = total > 0 ? totalLoss / total : 0;

    const elapsed = performance.now() - startTime;
    if (elapsed > 10) {
      console.warn(`DT training exceeded target: ${elapsed.toFixed(2)}ms > 10ms`);
    }

    return {
      loss: this.avgLoss,
      accuracy: total > 0 ? correct / total : 0,
    };
  }

  /**
   * Get action conditioned on target return.
   *
   * Entry 0 and entry 1 carry `targetReturn`; from then on the return-to-go
   * drops by a fixed 0.1 per step. Missing actions default to index 0.
   *
   * @param {Array<ArrayLike<number>>} states - One state vector per timestep.
   * @param {Array<number>} actions - Action index per timestep (may be shorter than `states`).
   * @param {number} targetReturn - Desired return used for conditioning.
   * @returns {number} Index of the most probable action.
   */
  getAction(states, actions, targetReturn) {
    // Build sequence
    const sequence = [];
    let returnToGo = targetReturn;
    for (let i = 0; i < states.length; i++) {
      sequence.push({
        returnToGo,
        state: states[i],
        action: actions[i] ?? 0,
        timestep: i,
      });
      // Decrease return-to-go by estimated reward
      if (i > 0) {
        returnToGo -= 0.1; // Default reward decrement for inference
      }
    }

    // Forward pass
    return this.argmax(this.forward(sequence));
  }

  /**
   * Forward pass through transformer.
   *
   * @param {Array<{returnToGo: number, state: ArrayLike<number>, action: number}>} sequence
   *   Only the last `contextLength` entries are used.
   * @returns {Float32Array} Action probabilities (length `numActions`). An
   *   empty sequence yields the uniform distribution instead of NaNs.
   */
  forward(sequence) {
    return this.actionProbabilities(this.encodeLastState(sequence));
  }

  /**
   * Get statistics
   */
  getStats() {
    return {
      updateCount: this.updateCount,
      bufferSize: this.trajectoryBuffer.length,
      avgLoss: this.avgLoss,
      contextLength: this.config.contextLength,
      numLayers: this.config.numLayers,
    };
  }

  // ==========================================================================
  // Private Methods
  // ==========================================================================

  /**
   * Embed the sequence, run the transformer layers and return the hidden
   * vector at the last timestep's state token.
   *
   * Token layout per timestep is [return, state, action], so the last state
   * token sits at index `3 * seqLen - 2`; causal attention keeps it from
   * seeing the action token that follows it.
   *
   * @param {Array} sequence
   * @returns {Float32Array|null} `embeddingDim` features, or null when there
   *   is nothing to encode (empty sequence or non-positive context length).
   */
  encodeLastState(sequence) {
    const seqLen = Math.min(sequence.length, this.config.contextLength);
    if (seqLen <= 0) {
      return null;
    }

    const embedDim = this.config.embeddingDim;
    const tokenCount = seqLen * 3;
    const firstEntry = sequence.length - seqLen;

    // Initialize hidden states (simplified: stack all modalities)
    let hidden = new Float32Array(tokenCount * embedDim);
    for (let t = 0; t < seqLen; t++) {
      const entry = sequence[firstEntry + t];
      const baseIdx = t * 3 * embedDim;

      // Embed return
      for (let d = 0; d < embedDim; d++) {
        hidden[baseIdx + d] = entry.returnToGo * this.returnEmbed[d];
      }

      // Embed state (state vectors longer than stateDim are truncated)
      const stateLen = Math.min(entry.state.length, this.stateDim);
      for (let d = 0; d < embedDim; d++) {
        let stateSum = 0;
        for (let s = 0; s < stateLen; s++) {
          stateSum += entry.state[s] * this.stateEmbed[s * embedDim + d];
        }
        hidden[baseIdx + embedDim + d] = stateSum;
      }

      // Embed action
      for (let d = 0; d < embedDim; d++) {
        hidden[baseIdx + 2 * embedDim + d] = this.actionEmbed[entry.action * embedDim + d];
      }

      // Add positional embedding
      for (let d = 0; d < 3 * embedDim; d++) {
        hidden[baseIdx + d] += this.posEmbed[t * 3 * embedDim + d] || 0;
      }
    }

    // Apply transformer layers
    for (let l = 0; l < this.config.numLayers; l++) {
      hidden = this.transformerLayer(hidden, tokenCount, l);
    }

    // Extract last state position embedding for action prediction
    const lastStateIdx = (tokenCount - 2) * embedDim;
    return hidden.slice(lastStateIdx, lastStateIdx + embedDim);
  }

  /**
   * Project features through the action head and normalize with softmax.
   *
   * @param {Float32Array|null} features - Output of `encodeLastState`.
   * @returns {Float32Array} Probabilities; uniform when `features` is null.
   */
  actionProbabilities(features) {
    if (features === null) {
      return new Float32Array(this.numActions).fill(1 / this.numActions);
    }
    const embedDim = this.config.embeddingDim;
    const logits = new Float32Array(this.numActions);
    for (let a = 0; a < this.numActions; a++) {
      let sum = 0;
      for (let d = 0; d < embedDim; d++) {
        sum += features[d] * this.actionHead[d * this.numActions + a];
      }
      logits[a] = sum;
    }
    return this.softmax(logits);
  }

  /**
   * Allocate an embedding table; same initialization as `initWeight`.
   */
  initEmbedding(inputDim, outputDim) {
    return this.initWeight(inputDim, outputDim);
  }

  /**
   * Allocate an `inputDim x outputDim` matrix filled with uniform values in
   * `[-0.5, 0.5) * sqrt(2 / inputDim)`.
   */
  initWeight(inputDim, outputDim) {
    const weight = new Float32Array(inputDim * outputDim);
    const scale = Math.sqrt(2 / inputDim);
    for (let i = 0; i < weight.length; i++) {
      weight[i] = (Math.random() - 0.5) * scale;
    }
    return weight;
  }

  /**
   * Convert a trajectory into model entries.
   *
   * Returns-to-go are discounted with `config.gamma`. Each entry uses the
   * step's `stateAfter` as its state and the hashed action as its action index.
   *
   * @param {{steps: Array<{reward: number, stateAfter: ArrayLike<number>, action: *}>}} trajectory
   * @returns {Array<{returnToGo: number, state: *, action: number, timestep: number}>}
   */
  createSequence(trajectory) {
    const { steps } = trajectory;

    // Compute returns-to-go
    const returnsToGo = new Array(steps.length).fill(0);
    let cumReturn = 0;
    for (let t = steps.length - 1; t >= 0; t--) {
      cumReturn = steps[t].reward + this.config.gamma * cumReturn;
      returnsToGo[t] = cumReturn;
    }

    // Create sequence entries
    return steps.map((step, t) => ({
      returnToGo: returnsToGo[t],
      state: step.stateAfter,
      action: this.hashAction(step.action),
      timestep: t,
    }));
  }

  /**
   * One transformer block: causal self-attention plus GELU feed-forward, each
   * with a residual connection (no layer norm, no dropout).
   *
   * Attention is simplified: all `hiddenDim` channels are scored with a single
   * dot product, scaled by `sqrt(hiddenDim / numHeads)`.
   *
   * @param {Float32Array} hidden - `seqLen * embeddingDim` token features.
   * @param {number} seqLen - Number of tokens.
   * @param {number} layerIdx - Index into the per-layer weight arrays.
   * @returns {Float32Array} New token features, same shape as `hidden`.
   */
  transformerLayer(hidden, seqLen, layerIdx) {
    const embedDim = this.config.embeddingDim;
    const hiddenDim = this.config.hiddenDim;
    const headDim = hiddenDim / this.config.numHeads;
    const scoreScale = Math.sqrt(headDim); // hoisted out of the score loop
    const output = new Float32Array(hidden.length);

    // Self-attention (simplified causal)
    const [Wq, Wk, Wv, Wo] = this.attentionWeights[layerIdx];

    // Compute Q, K, V for all positions
    const Q = new Float32Array(seqLen * hiddenDim);
    const K = new Float32Array(seqLen * hiddenDim);
    const V = new Float32Array(seqLen * hiddenDim);
    for (let pos = 0; pos < seqLen; pos++) {
      for (let h = 0; h < hiddenDim; h++) {
        let qSum = 0;
        let kSum = 0;
        let vSum = 0;
        for (let d = 0; d < embedDim; d++) {
          const hiddenVal = hidden[pos * embedDim + d];
          qSum += hiddenVal * Wq[d * hiddenDim + h];
          kSum += hiddenVal * Wk[d * hiddenDim + h];
          vSum += hiddenVal * Wv[d * hiddenDim + h];
        }
        Q[pos * hiddenDim + h] = qSum;
        K[pos * hiddenDim + h] = kSum;
        V[pos * hiddenDim + h] = vSum;
      }
    }

    // Causal attention: position `pos` only attends to keys 0..pos
    for (let pos = 0; pos < seqLen; pos++) {
      // Compute attention scores for current position
      const scores = new Float32Array(pos + 1);
      for (let k = 0; k <= pos; k++) {
        let score = 0;
        for (let h = 0; h < hiddenDim; h++) {
          score += Q[pos * hiddenDim + h] * K[k * hiddenDim + h];
        }
        scores[k] = score / scoreScale;
      }

      // Softmax (max-subtracted for numerical stability)
      const maxScore = Math.max(...scores);
      let sumExp = 0;
      for (let k = 0; k <= pos; k++) {
        scores[k] = Math.exp(scores[k] - maxScore);
        sumExp += scores[k];
      }
      for (let k = 0; k <= pos; k++) {
        scores[k] /= sumExp;
      }

      // Weighted sum of values
      const attnOut = new Float32Array(hiddenDim);
      for (let k = 0; k <= pos; k++) {
        for (let h = 0; h < hiddenDim; h++) {
          attnOut[h] += scores[k] * V[k * hiddenDim + h];
        }
      }

      // Output projection
      for (let d = 0; d < embedDim; d++) {
        let sum = hidden[pos * embedDim + d]; // Residual
        for (let h = 0; h < hiddenDim; h++) {
          sum += attnOut[h] * Wo[h * embedDim + d];
        }
        output[pos * embedDim + d] = sum;
      }
    }

    // FFN with residual
    const [Wup, Wdown] = this.ffnWeights[layerIdx];
    const ffnHiddenDim = hiddenDim * 4;
    for (let pos = 0; pos < seqLen; pos++) {
      // Up projection + GELU
      const ffnHidden = new Float32Array(ffnHiddenDim);
      for (let h = 0; h < ffnHiddenDim; h++) {
        let sum = 0;
        for (let d = 0; d < embedDim; d++) {
          sum += output[pos * embedDim + d] * Wup[d * ffnHiddenDim + h];
        }
        // GELU approximation
        ffnHidden[h] = sum * 0.5 * (1 + Math.tanh(0.7978845608 * (sum + 0.044715 * sum * sum * sum)));
      }

      // Down projection. Each output[pos, d] is read (as the residual) before
      // it is overwritten, and ffnHidden was fully computed above, so updating
      // in place is safe.
      for (let d = 0; d < embedDim; d++) {
        let sum = output[pos * embedDim + d]; // Residual
        for (let h = 0; h < ffnHiddenDim; h++) {
          sum += ffnHidden[h] * Wdown[h * embedDim + d];
        }
        output[pos * embedDim + d] = sum;
      }
    }

    return output;
  }

  /**
   * One SGD step on the action head for the cross-entropy loss.
   *
   * The gradient w.r.t. `actionHead[d][a]` is `features[d] * (p[a] - y[a])`.
   * Previously every row received the same `p[a] - y[a]` step regardless of
   * the features, which is not a descent direction. The 0.1 damping factor of
   * the original update is kept.
   *
   * @param {Array} context - Sequence the prediction was made from.
   * @param {number} targetAction - Index of the correct action.
   * @param {ArrayLike<number>} predicted - Probabilities produced for `context`.
   * @param {Float32Array|null} [features] - Pre-computed `encodeLastState(context)`;
   *   recomputed from `context` when omitted.
   */
  updateWeights(context, targetAction, predicted, features = this.encodeLastState(context)) {
    if (features === null) {
      return; // nothing was encoded, so there is no gradient signal
    }
    const step = this.config.learningRate * 0.1;
    const embedDim = this.config.embeddingDim;

    // Gradient of cross-entropy w.r.t. the logits
    const grad = new Float32Array(this.numActions);
    for (let a = 0; a < this.numActions; a++) {
      grad[a] = predicted[a] - (a === targetAction ? 1 : 0);
    }

    // Update action head
    for (let d = 0; d < embedDim; d++) {
      for (let a = 0; a < this.numActions; a++) {
        const delta = step * grad[a] * features[d];
        // Skip non-finite steps so one bad sample cannot poison the weights.
        if (Number.isFinite(delta)) {
          this.actionHead[d * this.numActions + a] -= delta;
        }
      }
    }
  }

  /**
   * Numerically stable softmax.
   *
   * @param {ArrayLike<number>} logits
   * @returns {Float32Array}
   */
  softmax(logits) {
    const max = Math.max(...logits);
    const exps = new Float32Array(logits.length);
    let sum = 0;
    for (let i = 0; i < logits.length; i++) {
      exps[i] = Math.exp(logits[i] - max);
      sum += exps[i];
    }
    for (let i = 0; i < exps.length; i++) {
      exps[i] /= sum;
    }
    return exps;
  }

  /**
   * Index of the first maximum element.
   */
  argmax(values) {
    let maxIdx = 0;
    let maxVal = values[0];
    for (let i = 1; i < values.length; i++) {
      if (values[i] > maxVal) {
        maxVal = values[i];
        maxIdx = i;
      }
    }
    return maxIdx;
  }

  /**
   * Map an action to an index in `[0, numActions)`.
   *
   * Strings are hashed with a rolling base-31 hash. Integer actions are
   * treated as indices and wrapped into range (they used to collapse to 0).
   *
   * @param {string|number} action
   * @returns {number}
   */
  hashAction(action) {
    if (Number.isInteger(action)) {
      return ((action % this.numActions) + this.numActions) % this.numActions;
    }
    let hash = 0;
    for (let i = 0; i < action.length; i++) {
      hash = (hash * 31 + action.charCodeAt(i)) % this.numActions;
    }
    return hash;
  }
}

/**
 * Factory function
 */
export function createDecisionTransformer(config) {
  return new DecisionTransformer(config);
}
//# sourceMappingURL=decision-transformer.js.map