/**
 * Decision Transformer
 *
 * Implements sequence modeling approach for RL:
 * - Trajectory as sequence: (s, a, R, s, a, R, ...)
 * - Return-conditioned generation
 * - Causal transformer attention
 * - Offline RL from trajectories
 *
 * Performance Target: <10ms per forward pass
 */
|
|
/**
 * Default Decision Transformer configuration.
 *
 * Frozen to prevent accidental mutation of shared defaults; consumers
 * should copy-merge overrides (`{ ...DEFAULT_DT_CONFIG, ...overrides }`),
 * which is how the DecisionTransformer constructor uses it.
 */
export const DEFAULT_DT_CONFIG = Object.freeze({
  algorithm: 'decision-transformer',
  learningRate: 0.0001, // SGD step size for the action-head update
  gamma: 0.99, // discount used when computing returns-to-go
  entropyCoef: 0, // not referenced in this file
  valueLossCoef: 0, // not referenced in this file
  maxGradNorm: 1.0, // not referenced in this file (no gradient clipping)
  epochs: 1, // not referenced in this file
  miniBatchSize: 64, // trajectories sampled per train() call
  contextLength: 20, // timesteps of (return, state, action) context
  numHeads: 4, // only affects attention score scaling, not head splitting
  numLayers: 2, // transformer layers
  hiddenDim: 64, // attention projection width
  embeddingDim: 32, // per-token embedding width
  dropout: 0.1, // declared but not applied anywhere in this file
});
|
|
/**
 * Decision Transformer Implementation.
 *
 * Models a trajectory as the flat token sequence
 * (return-to-go, state, action) repeated per timestep, and predicts the
 * next action with a causal transformer conditioned on a target return
 * (offline RL as sequence modeling).
 *
 * Simplifications visible in this implementation:
 * - Attention is effectively single-head: `numHeads` only enters through
 *   the score scaling factor sqrt(headDim).
 * - No layer normalization; `dropout` from the config is never applied.
 * - Training updates ONLY the action head with a crude SGD step; all
 *   embeddings and transformer weights stay at their random init.
 */
export class DecisionTransformer {
  // Merged DEFAULT_DT_CONFIG + user overrides (set in constructor).
  config;

  // Embeddings — flat Float32Arrays, row-major [inputDim x embeddingDim].
  stateEmbed; // stateDim x embeddingDim projection for dense state vectors
  actionEmbed; // numActions x embeddingDim lookup table
  returnEmbed; // 1 x embeddingDim; scaled by the scalar return-to-go
  posEmbed; // (contextLength * 3) x embeddingDim positional table

  // Transformer layers (simplified): per layer [Wq, Wk, Wv, Wo] and [Wup, Wdown].
  attentionWeights;
  ffnWeights;

  // Output head: embeddingDim x numActions logits projection.
  actionHead;

  // Training buffer of completed trajectories (bounded at 1000).
  trajectoryBuffer = [];

  // Dimensions — fixed constants, not derived from config.
  stateDim = 768;
  numActions = 4;

  // Statistics
  updateCount = 0; // number of train() calls performed
  avgLoss = 0; // mean loss of the MOST RECENT batch (not a running average)

  /**
   * @param {object} [config] - Partial config merged over DEFAULT_DT_CONFIG.
   */
  constructor(config = {}) {
    this.config = { ...DEFAULT_DT_CONFIG, ...config };
    // Initialize embeddings
    this.stateEmbed = this.initEmbedding(this.stateDim, this.config.embeddingDim);
    this.actionEmbed = this.initEmbedding(this.numActions, this.config.embeddingDim);
    this.returnEmbed = this.initEmbedding(1, this.config.embeddingDim);
    // 3 tokens (return, state, action) per timestep -> contextLength * 3 positions.
    this.posEmbed = this.initEmbedding(this.config.contextLength * 3, this.config.embeddingDim);
    // Initialize transformer layers
    this.attentionWeights = [];
    this.ffnWeights = [];
    for (let l = 0; l < this.config.numLayers; l++) {
      // Attention: Q, K, V, O projections
      this.attentionWeights.push([
        this.initWeight(this.config.embeddingDim, this.config.hiddenDim), // Q
        this.initWeight(this.config.embeddingDim, this.config.hiddenDim), // K
        this.initWeight(this.config.embeddingDim, this.config.hiddenDim), // V
        this.initWeight(this.config.hiddenDim, this.config.embeddingDim), // O
      ]);
      // FFN: up (4x expansion) and down projections
      this.ffnWeights.push([
        this.initWeight(this.config.embeddingDim, this.config.hiddenDim * 4),
        this.initWeight(this.config.hiddenDim * 4, this.config.embeddingDim),
      ]);
    }
    // Action prediction head
    this.actionHead = this.initWeight(this.config.embeddingDim, this.numActions);
  }

  /**
   * Add a trajectory to the training buffer.
   * Incomplete or empty trajectories are silently ignored.
   *
   * @param {{isComplete: boolean, steps: Array}} trajectory
   */
  addTrajectory(trajectory) {
    if (trajectory.isComplete && trajectory.steps.length > 0) {
      this.trajectoryBuffer.push(trajectory);
      // Keep buffer bounded to the 1000 most recent trajectories.
      if (this.trajectoryBuffer.length > 1000) {
        this.trajectoryBuffer = this.trajectoryBuffer.slice(-1000);
      }
    }
  }

  /**
   * Train on buffered trajectories.
   * Target: <10ms per batch (warns to console if exceeded).
   *
   * Samples `miniBatchSize` trajectories WITH replacement, then for each
   * position in each trajectory does a forward pass over the preceding
   * context window and a simplified action-head update.
   *
   * @returns {{loss: number, accuracy: number}} batch mean cross-entropy
   *   loss and action-prediction accuracy (both 0 when buffer is empty).
   */
  train() {
    const startTime = performance.now();
    if (this.trajectoryBuffer.length === 0) {
      return { loss: 0, accuracy: 0 };
    }
    // Sample mini-batch of trajectories (uniform, with replacement).
    const batchSize = Math.min(this.config.miniBatchSize, this.trajectoryBuffer.length);
    const batch = [];
    for (let i = 0; i < batchSize; i++) {
      const idx = Math.floor(Math.random() * this.trajectoryBuffer.length);
      batch.push(this.trajectoryBuffer[idx]);
    }
    let totalLoss = 0;
    let correct = 0;
    let total = 0;
    for (const trajectory of batch) {
      // Create sequence from trajectory
      const sequence = this.createSequence(trajectory);
      if (sequence.length < 2)
        continue;
      // Forward pass and compute loss at every position.
      for (let t = 1; t < sequence.length; t++) {
        // Use context up to position t. NOTE(review): the slice EXCLUDES
        // index t, so the prediction for target.action never sees the
        // return/state at t; standard DT conditions on (R_t, s_t) when
        // predicting a_t — confirm this off-by-one is intentional.
        const context = sequence.slice(Math.max(0, t - this.config.contextLength), t);
        const target = sequence[t];
        // Predict action (softmax probabilities over numActions)
        const predicted = this.forward(context);
        const predictedAction = this.argmax(predicted);
        // Cross-entropy loss (epsilon guards log(0))
        const loss = -Math.log(predicted[target.action] + 1e-8);
        totalLoss += loss;
        if (predictedAction === target.action) {
          correct++;
        }
        total++;
        // Gradient update (simplified; action head only)
        this.updateWeights(context, target.action, predicted);
      }
    }
    this.updateCount++;
    // Despite the name, this is the mean loss of THIS batch only.
    this.avgLoss = total > 0 ? totalLoss / total : 0;
    const elapsed = performance.now() - startTime;
    if (elapsed > 10) {
      console.warn(`DT training exceeded target: ${elapsed.toFixed(2)}ms > 10ms`);
    }
    return {
      loss: this.avgLoss,
      accuracy: total > 0 ? correct / total : 0,
    };
  }

  /**
   * Get action conditioned on target return (inference).
   *
   * @param {Array<ArrayLike<number>>} states - observed state vectors so far
   * @param {Array<number>} actions - past discrete action ids (may be shorter
   *   than states; missing entries default to action 0)
   * @param {number} targetReturn - desired return-to-go at the first step
   * @returns {number} greedy (argmax) action id
   */
  getAction(states, actions, targetReturn) {
    // Build sequence
    const sequence = [];
    let returnToGo = targetReturn;
    for (let i = 0; i < states.length; i++) {
      sequence.push({
        returnToGo,
        state: states[i],
        action: actions[i] ?? 0, // pad missing past actions with action 0
        timestep: i,
      });
      // Decrease return-to-go by estimated reward
      // (fixed 0.1 heuristic; actual rewards are unknown at inference).
      if (i > 0) {
        returnToGo -= 0.1; // Default reward decrement for inference
      }
    }
    // Forward pass
    const logits = this.forward(sequence);
    return this.argmax(logits);
  }

  /**
   * Forward pass through the transformer.
   *
   * Token layout: timestep t occupies token positions 3t (return),
   * 3t+1 (state), 3t+2 (action). Only the most recent `contextLength`
   * timesteps of `sequence` are used.
   *
   * @param {Array<{returnToGo: number, state: ArrayLike<number>, action: number}>} sequence
   * @returns {Float32Array} softmax action probabilities (length numActions)
   */
  forward(sequence) {
    // Embed sequence elements
    const seqLen = Math.min(sequence.length, this.config.contextLength);
    const embedDim = this.config.embeddingDim;
    // Initialize hidden states (simplified: stack all modalities)
    const hidden = new Float32Array(seqLen * 3 * embedDim);
    for (let t = 0; t < seqLen; t++) {
      // Take the LAST seqLen entries of the sequence.
      const entry = sequence[sequence.length - seqLen + t];
      const baseIdx = t * 3 * embedDim;
      // Embed return: scalar return-to-go scales the learned vector.
      for (let d = 0; d < embedDim; d++) {
        hidden[baseIdx + d] = entry.returnToGo * this.returnEmbed[d];
      }
      // Embed state: dense projection state · stateEmbed.
      for (let d = 0; d < embedDim; d++) {
        let stateSum = 0;
        for (let s = 0; s < Math.min(entry.state.length, this.stateDim); s++) {
          stateSum += entry.state[s] * this.stateEmbed[s * embedDim + d];
        }
        hidden[baseIdx + embedDim + d] = stateSum;
      }
      // Embed action: table lookup by discrete action id.
      for (let d = 0; d < embedDim; d++) {
        hidden[baseIdx + 2 * embedDim + d] = this.actionEmbed[entry.action * embedDim + d];
      }
      // Add positional embedding (|| 0 is a defensive out-of-range guard).
      for (let d = 0; d < 3 * embedDim; d++) {
        hidden[baseIdx + d] += this.posEmbed[t * 3 * embedDim + d] || 0;
      }
    }
    // Apply transformer layers (token count = seqLen * 3).
    for (let l = 0; l < this.config.numLayers; l++) {
      hidden.set(this.transformerLayer(hidden, seqLen * 3, l));
    }
    // Extract the final STATE token (position 3*seqLen - 2) for action
    // prediction — the standard DT readout position.
    const lastStateIdx = (seqLen * 3 - 2) * embedDim;
    const lastState = hidden.slice(lastStateIdx, lastStateIdx + embedDim);
    // Action prediction: lastState · actionHead, then softmax.
    const logits = new Float32Array(this.numActions);
    for (let a = 0; a < this.numActions; a++) {
      let sum = 0;
      for (let d = 0; d < embedDim; d++) {
        sum += lastState[d] * this.actionHead[d * this.numActions + a];
      }
      logits[a] = sum;
    }
    return this.softmax(logits);
  }

  /**
   * Get statistics snapshot for monitoring.
   *
   * @returns {{updateCount: number, bufferSize: number, avgLoss: number,
   *   contextLength: number, numLayers: number}}
   */
  getStats() {
    return {
      updateCount: this.updateCount,
      bufferSize: this.trajectoryBuffer.length,
      avgLoss: this.avgLoss,
      contextLength: this.config.contextLength,
      numLayers: this.config.numLayers,
    };
  }

  // ==========================================================================
  // Private Methods
  // ==========================================================================

  // Allocate a flat [inputDim x outputDim] table with scaled uniform init
  // (uniform in [-scale/2, scale/2], scale = sqrt(2/inputDim)).
  initEmbedding(inputDim, outputDim) {
    const embed = new Float32Array(inputDim * outputDim);
    const scale = Math.sqrt(2 / inputDim);
    for (let i = 0; i < embed.length; i++) {
      embed[i] = (Math.random() - 0.5) * scale;
    }
    return embed;
  }

  // Same scheme as initEmbedding; kept separate for readability at call sites.
  initWeight(inputDim, outputDim) {
    const weight = new Float32Array(inputDim * outputDim);
    const scale = Math.sqrt(2 / inputDim);
    for (let i = 0; i < weight.length; i++) {
      weight[i] = (Math.random() - 0.5) * scale;
    }
    return weight;
  }

  // Convert a trajectory into the DT token sequence, computing discounted
  // returns-to-go backwards with config.gamma.
  createSequence(trajectory) {
    const sequence = [];
    // Compute returns-to-go: R_t = r_t + gamma * R_{t+1}.
    const rewards = trajectory.steps.map(s => s.reward);
    const returnsToGo = new Array(rewards.length).fill(0);
    let cumReturn = 0;
    for (let t = rewards.length - 1; t >= 0; t--) {
      cumReturn = rewards[t] + this.config.gamma * cumReturn;
      returnsToGo[t] = cumReturn;
    }
    // Create sequence entries.
    // NOTE(review): uses stateAfter (post-step state) as the state paired
    // with action t — confirm against how trajectory steps are recorded.
    for (let t = 0; t < trajectory.steps.length; t++) {
      sequence.push({
        returnToGo: returnsToGo[t],
        state: trajectory.steps[t].stateAfter,
        action: this.hashAction(trajectory.steps[t].action),
        timestep: t,
      });
    }
    return sequence;
  }

  // One transformer block: causal self-attention + FFN, each with a
  // residual connection. No layer norm. `seqLen` here is the TOKEN count
  // (3 tokens per timestep). `hidden` is read-only; returns a new array.
  transformerLayer(hidden, seqLen, layerIdx) {
    const embedDim = this.config.embeddingDim;
    const hiddenDim = this.config.hiddenDim;
    const numHeads = this.config.numHeads;
    // headDim is used ONLY for score scaling; heads are never split.
    const headDim = hiddenDim / numHeads;
    const output = new Float32Array(hidden.length);
    // Self-attention (simplified causal)
    const [Wq, Wk, Wv, Wo] = this.attentionWeights[layerIdx];
    // Compute Q, K, V for all positions
    const Q = new Float32Array(seqLen * hiddenDim);
    const K = new Float32Array(seqLen * hiddenDim);
    const V = new Float32Array(seqLen * hiddenDim);
    for (let pos = 0; pos < seqLen; pos++) {
      for (let h = 0; h < hiddenDim; h++) {
        let qSum = 0, kSum = 0, vSum = 0;
        for (let d = 0; d < embedDim; d++) {
          const hiddenVal = hidden[pos * embedDim + d];
          qSum += hiddenVal * Wq[d * hiddenDim + h];
          kSum += hiddenVal * Wk[d * hiddenDim + h];
          vSum += hiddenVal * Wv[d * hiddenDim + h];
        }
        Q[pos * hiddenDim + h] = qSum;
        K[pos * hiddenDim + h] = kSum;
        V[pos * hiddenDim + h] = vSum;
      }
    }
    // Causal attention: position pos attends only to keys 0..pos.
    for (let pos = 0; pos < seqLen; pos++) {
      // Compute attention scores for current position
      const scores = new Float32Array(pos + 1);
      for (let k = 0; k <= pos; k++) {
        let score = 0;
        for (let h = 0; h < hiddenDim; h++) {
          score += Q[pos * hiddenDim + h] * K[k * hiddenDim + h];
        }
        // Scaled dot-product; sqrt(headDim) rather than sqrt(hiddenDim).
        scores[k] = score / Math.sqrt(headDim);
      }
      // Softmax (max-subtracted for numerical stability)
      const maxScore = Math.max(...scores);
      let sumExp = 0;
      for (let k = 0; k <= pos; k++) {
        scores[k] = Math.exp(scores[k] - maxScore);
        sumExp += scores[k];
      }
      for (let k = 0; k <= pos; k++) {
        scores[k] /= sumExp;
      }
      // Weighted sum of values
      const attnOut = new Float32Array(hiddenDim);
      for (let k = 0; k <= pos; k++) {
        for (let h = 0; h < hiddenDim; h++) {
          attnOut[h] += scores[k] * V[k * hiddenDim + h];
        }
      }
      // Output projection back to embedDim, plus residual.
      for (let d = 0; d < embedDim; d++) {
        let sum = hidden[pos * embedDim + d]; // Residual
        for (let h = 0; h < hiddenDim; h++) {
          sum += attnOut[h] * Wo[h * embedDim + d];
        }
        output[pos * embedDim + d] = sum;
      }
    }
    // FFN with residual (operates on attention output in place).
    const [Wup, Wdown] = this.ffnWeights[layerIdx];
    const ffnHiddenDim = hiddenDim * 4;
    for (let pos = 0; pos < seqLen; pos++) {
      // Up projection + GELU
      const ffnHidden = new Float32Array(ffnHiddenDim);
      for (let h = 0; h < ffnHiddenDim; h++) {
        let sum = 0;
        for (let d = 0; d < embedDim; d++) {
          sum += output[pos * embedDim + d] * Wup[d * ffnHiddenDim + h];
        }
        // GELU approximation (tanh form, 0.7978845608 = sqrt(2/pi))
        ffnHidden[h] = sum * 0.5 * (1 + Math.tanh(0.7978845608 * (sum + 0.044715 * sum * sum * sum)));
      }
      // Down projection
      for (let d = 0; d < embedDim; d++) {
        let sum = output[pos * embedDim + d]; // Residual
        for (let h = 0; h < ffnHiddenDim; h++) {
          sum += ffnHidden[h] * Wdown[h * embedDim + d];
        }
        output[pos * embedDim + d] = sum;
      }
    }
    return output;
  }

  // Simplified SGD step on the action head only. `context` is currently
  // unused here.
  updateWeights(context, targetAction, predicted) {
    // Simplified gradient update for action head
    const lr = this.config.learningRate;
    const embedDim = this.config.embeddingDim;
    // Gradient of cross-entropy w.r.t. logits: p - one_hot(target).
    const grad = new Float32Array(this.numActions);
    for (let a = 0; a < this.numActions; a++) {
      grad[a] = predicted[a] - (a === targetAction ? 1 : 0);
    }
    // Update action head (simplified).
    // NOTE(review): the true gradient w.r.t. actionHead[d][a] is
    // grad[a] * lastState[d]; the activation term is replaced by a
    // constant 0.1 here, so every row of the head moves identically.
    // Confirm this approximation is intentional.
    for (let d = 0; d < embedDim; d++) {
      for (let a = 0; a < this.numActions; a++) {
        this.actionHead[d * this.numActions + a] -= lr * grad[a] * 0.1;
      }
    }
  }

  // Numerically stable softmax (max-subtracted) over a logits array.
  softmax(logits) {
    const max = Math.max(...logits);
    const exps = new Float32Array(logits.length);
    let sum = 0;
    for (let i = 0; i < logits.length; i++) {
      exps[i] = Math.exp(logits[i] - max);
      sum += exps[i];
    }
    for (let i = 0; i < exps.length; i++) {
      exps[i] /= sum;
    }
    return exps;
  }

  // Index of the maximum value; ties resolve to the earliest index.
  argmax(values) {
    let maxIdx = 0;
    let maxVal = values[0];
    for (let i = 1; i < values.length; i++) {
      if (values[i] > maxVal) {
        maxVal = values[i];
        maxIdx = i;
      }
    }
    return maxIdx;
  }

  // Deterministically map an action STRING to a discrete id in
  // [0, numActions) via a rolling polynomial hash. Distinct actions may
  // collide into the same id.
  hashAction(action) {
    let hash = 0;
    for (let i = 0; i < action.length; i++) {
      hash = (hash * 31 + action.charCodeAt(i)) % this.numActions;
    }
    return hash;
  }
}
|
|
/**
 * Factory function for DecisionTransformer.
 *
 * @param {object} [config] - Partial configuration merged over
 *   DEFAULT_DT_CONFIG by the constructor; omitted fields use defaults.
 * @returns {DecisionTransformer} a freshly initialized model
 */
export function createDecisionTransformer(config = {}) {
  return new DecisionTransformer(config);
}
//# sourceMappingURL=decision-transformer.js.map