599 lines
18 KiB
HTML
599 lines
18 KiB
HTML
<!DOCTYPE html>
|
|
<html lang="en">
|
|
<head>
|
|
<meta charset="UTF-8">
|
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
<title>AgentDB - Flash Attention &amp; Memory Consolidation</title>
|
|
<style>
|
|
* {
|
|
margin: 0;
|
|
padding: 0;
|
|
box-sizing: border-box;
|
|
}
|
|
|
|
body {
|
|
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
|
background: linear-gradient(135deg, #0f2027 0%, #203a43 50%, #2c5364 100%);
|
|
padding: 20px;
|
|
min-height: 100vh;
|
|
color: white;
|
|
}
|
|
|
|
.container {
|
|
max-width: 1400px;
|
|
margin: 0 auto;
|
|
}
|
|
|
|
header {
|
|
background: rgba(255, 255, 255, 0.1);
|
|
backdrop-filter: blur(10px);
|
|
padding: 30px;
|
|
border-radius: 10px;
|
|
margin-bottom: 30px;
|
|
border: 1px solid rgba(255, 255, 255, 0.2);
|
|
}
|
|
|
|
h1 {
|
|
color: #4fc3f7;
|
|
margin-bottom: 10px;
|
|
}
|
|
|
|
.demo-section {
|
|
background: rgba(255, 255, 255, 0.05);
|
|
backdrop-filter: blur(10px);
|
|
padding: 30px;
|
|
border-radius: 10px;
|
|
margin-bottom: 20px;
|
|
border: 1px solid rgba(255, 255, 255, 0.1);
|
|
}
|
|
|
|
h2 {
|
|
color: #4fc3f7;
|
|
margin-bottom: 20px;
|
|
padding-bottom: 10px;
|
|
border-bottom: 2px solid rgba(79, 195, 247, 0.3);
|
|
}
|
|
|
|
button {
|
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
|
color: white;
|
|
border: none;
|
|
padding: 12px 24px;
|
|
border-radius: 5px;
|
|
cursor: pointer;
|
|
font-size: 16px;
|
|
margin-right: 10px;
|
|
margin-bottom: 10px;
|
|
transition: transform 0.2s;
|
|
}
|
|
|
|
button:hover {
|
|
transform: translateY(-2px);
|
|
}
|
|
|
|
.comparison {
|
|
display: grid;
|
|
grid-template-columns: 1fr 1fr;
|
|
gap: 20px;
|
|
margin-top: 20px;
|
|
}
|
|
|
|
.comparison-panel {
|
|
background: rgba(0, 0, 0, 0.3);
|
|
padding: 20px;
|
|
border-radius: 5px;
|
|
border: 1px solid rgba(79, 195, 247, 0.2);
|
|
}
|
|
|
|
.comparison-panel h3 {
|
|
color: #4fc3f7;
|
|
margin-bottom: 15px;
|
|
}
|
|
|
|
.metric-row {
|
|
display: flex;
|
|
justify-content: space-between;
|
|
padding: 10px 0;
|
|
border-bottom: 1px solid rgba(255, 255, 255, 0.1);
|
|
}
|
|
|
|
.metric-label {
|
|
color: #b0bec5;
|
|
}
|
|
|
|
.metric-value {
|
|
color: white;
|
|
font-weight: bold;
|
|
}
|
|
|
|
.metric-value.good {
|
|
color: #4caf50;
|
|
}
|
|
|
|
.metric-value.bad {
|
|
color: #f44336;
|
|
}
|
|
|
|
canvas {
|
|
width: 100%;
|
|
height: 300px;
|
|
background: rgba(0, 0, 0, 0.3);
|
|
border-radius: 5px;
|
|
display: block;
|
|
}
|
|
|
|
.chart-container {
|
|
margin-top: 20px;
|
|
}
|
|
|
|
.progress-bar {
|
|
width: 100%;
|
|
height: 30px;
|
|
background: rgba(0, 0, 0, 0.3);
|
|
border-radius: 15px;
|
|
overflow: hidden;
|
|
margin: 20px 0;
|
|
}
|
|
|
|
.progress-fill {
|
|
height: 100%;
|
|
background: linear-gradient(90deg, #4fc3f7 0%, #667eea 100%);
|
|
transition: width 0.3s;
|
|
display: flex;
|
|
align-items: center;
|
|
justify-content: center;
|
|
font-weight: bold;
|
|
}
|
|
|
|
footer {
|
|
text-align: center;
|
|
margin-top: 30px;
|
|
opacity: 0.7;
|
|
}
|
|
</style>
|
|
</head>
|
|
<body>
|
|
<div class="container">
|
|
<header>
|
|
<h1>Flash Attention &amp; Memory Consolidation</h1>
|
|
<p>Compare traditional attention with Flash Attention and memory consolidation strategies</p>
|
|
</header>
|
|
|
|
<!-- Attention Comparison -->
|
|
<div class="demo-section">
|
|
<h2>Attention Mechanism Comparison</h2>
|
|
<div>
|
|
<button id="runStandard">Run Standard Attention</button>
|
|
<button id="runFlash">Run Flash Attention</button>
|
|
<button id="compareAttention">Compare Both</button>
|
|
<button id="benchmarkScaling">Benchmark Scaling</button>
|
|
</div>
|
|
|
|
<div class="comparison">
|
|
<div class="comparison-panel">
|
|
<h3>Standard Attention</h3>
|
|
<div id="standardMetrics">
|
|
<div class="metric-row">
|
|
<span class="metric-label">Time:</span>
|
|
<span class="metric-value" id="standardTime">-</span>
|
|
</div>
|
|
<div class="metric-row">
|
|
<span class="metric-label">Memory:</span>
|
|
<span class="metric-value" id="standardMemory">-</span>
|
|
</div>
|
|
<div class="metric-row">
|
|
<span class="metric-label">Memory Complexity:</span>
|
|
<span class="metric-value bad">O(N²)</span>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
|
|
<div class="comparison-panel">
|
|
<h3>Flash Attention</h3>
|
|
<div id="flashMetrics">
|
|
<div class="metric-row">
|
|
<span class="metric-label">Time:</span>
|
|
<span class="metric-value" id="flashTime">-</span>
|
|
</div>
|
|
<div class="metric-row">
|
|
<span class="metric-label">Memory:</span>
|
|
<span class="metric-value" id="flashMemory">-</span>
|
|
</div>
|
|
<div class="metric-row">
|
|
<span class="metric-label">Memory Complexity:</span>
|
|
<span class="metric-value good">O(N)</span>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
|
|
<div class="chart-container">
|
|
<canvas id="attentionChart"></canvas>
|
|
</div>
|
|
</div>
|
|
|
|
<!-- Memory Consolidation -->
|
|
<div class="demo-section">
|
|
<h2>Memory Consolidation</h2>
|
|
<div>
|
|
<button id="generateMemories">Generate Random Memories</button>
|
|
<button id="consolidate">Consolidate Memories</button>
|
|
<button id="compareConsolidation">Compare Strategies</button>
|
|
</div>
|
|
|
|
<div class="progress-bar">
|
|
<div class="progress-fill" id="consolidationProgress" style="width: 0%;">0%</div>
|
|
</div>
|
|
|
|
<div class="comparison">
|
|
<div class="comparison-panel">
|
|
<h3>Before Consolidation</h3>
|
|
<div id="beforeMetrics">
|
|
<div class="metric-row">
|
|
<span class="metric-label">Total Memories:</span>
|
|
<span class="metric-value" id="beforeCount">-</span>
|
|
</div>
|
|
<div class="metric-row">
|
|
<span class="metric-label">Memory Size:</span>
|
|
<span class="metric-value" id="beforeSize">-</span>
|
|
</div>
|
|
<div class="metric-row">
|
|
<span class="metric-label">Redundancy:</span>
|
|
<span class="metric-value bad" id="beforeRedundancy">-</span>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
|
|
<div class="comparison-panel">
|
|
<h3>After Consolidation</h3>
|
|
<div id="afterMetrics">
|
|
<div class="metric-row">
|
|
<span class="metric-label">Clusters:</span>
|
|
<span class="metric-value good" id="afterCount">-</span>
|
|
</div>
|
|
<div class="metric-row">
|
|
<span class="metric-label">Memory Size:</span>
|
|
<span class="metric-value good" id="afterSize">-</span>
|
|
</div>
|
|
<div class="metric-row">
|
|
<span class="metric-label">Compression:</span>
|
|
<span class="metric-value good" id="compression">-</span>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
|
|
<div class="chart-container">
|
|
<canvas id="consolidationChart"></canvas>
|
|
</div>
|
|
</div>
|
|
|
|
<footer>
|
|
<p>AgentDB v2.0 | Flash Attention &amp; Memory Consolidation powered by WASM</p>
|
|
</footer>
|
|
</div>
|
|
|
|
<script type="module">
|
|
import {
|
|
AttentionBrowser,
|
|
createAttention,
|
|
createFastAttention
|
|
} from '../../dist/agentdb.browser.js';
|
|
|
|
// Attention engine used by the click handlers; assigned inside initialize().
let attention = null;
|
|
// Memory vectors built by the 'generateMemories' click handler; stays null until that runs.
let currentMemories = null;
|
|
|
|
/**
 * Creates the attention engine used by the demo and waits for it to become ready.
 *
 * The instance is stored in the module-level `attention` variable before the
 * await, so `attention` is non-null as soon as this function has been called,
 * even though its own `initialize()` may still be pending.
 *
 * @returns {Promise<void>} Settles when `attention.initialize()` settles.
 */
async function initialize() {
  const engineOptions = { dimension: 256, numHeads: 4, useWASM: true };
  attention = createAttention(engineOptions);
  await attention.initialize();
}
|
|
|
|
/**
 * Naive scaled dot-product attention, used as the baseline the demo compares
 * Flash Attention against.
 *
 * All inputs are flat, row-major typed arrays made of `dim`-wide rows. The
 * full (query rows x key rows) weight matrix is materialized on purpose: that
 * quadratic buffer is exactly what this baseline is meant to demonstrate.
 *
 * Previously `dim` was taken from `query.length`, so a flattened multi-row
 * query collapsed the whole computation to a single 1x1 "matrix". The row
 * width is now an explicit, optional parameter.
 *
 * @param {Float32Array} query  One or more query rows, `dim` values each.
 * @param {Float32Array} keys   Key rows, `dim` values each.
 * @param {Float32Array} values Value rows, `dim` values each; same length as `keys`.
 * @param {number} [dim=256]    Width of one row. Defaults to the dimension the
 *                              demo uses everywhere (256).
 * @returns {Float32Array} Attention output, same length and layout as `query`.
 * @throws {RangeError} If `dim` is not a positive integer or the array lengths
 *                      are inconsistent with it.
 */
function standardAttention(query, keys, values, dim = 256) {
  if (!Number.isInteger(dim) || dim <= 0) {
    throw new RangeError(`dim must be a positive integer, got ${dim}`);
  }
  if (query.length % dim !== 0 || keys.length % dim !== 0) {
    throw new RangeError('query and keys lengths must be multiples of dim');
  }
  if (values.length !== keys.length) {
    throw new RangeError('values must have the same length as keys');
  }

  const numQueries = query.length / dim;
  const seqLen = keys.length / dim;
  const output = new Float32Array(query.length);
  if (seqLen === 0) {
    return output; // nothing to attend to
  }

  const scale = 1 / Math.sqrt(dim);

  // Full attention matrix: one weight per (query row, key row) pair.
  const attentionMatrix = new Float32Array(numQueries * seqLen);

  for (let i = 0; i < numQueries; i++) {
    const qOffset = i * dim;
    const rowOffset = i * seqLen;

    // Raw scaled scores, tracking the row maximum for a stable softmax.
    let maxScore = -Infinity;
    for (let j = 0; j < seqLen; j++) {
      const kOffset = j * dim;
      let dot = 0;
      for (let d = 0; d < dim; d++) {
        dot += query[qOffset + d] * keys[kOffset + d];
      }
      const score = dot * scale;
      attentionMatrix[rowOffset + j] = score;
      if (score > maxScore) {
        maxScore = score;
      }
    }

    // Exponentiate relative to the maximum so large scores cannot overflow.
    let sumWeights = 0;
    for (let j = 0; j < seqLen; j++) {
      const weight = Math.exp(attentionMatrix[rowOffset + j] - maxScore);
      attentionMatrix[rowOffset + j] = weight;
      sumWeights += weight;
    }

    // Normalize
    for (let j = 0; j < seqLen; j++) {
      attentionMatrix[rowOffset + j] /= sumWeights;
    }
  }

  // Apply to values
  for (let i = 0; i < numQueries; i++) {
    const outOffset = i * dim;
    for (let j = 0; j < seqLen; j++) {
      const weight = attentionMatrix[i * seqLen + j];
      const vOffset = j * dim;
      for (let d = 0; d < dim; d++) {
        output[outOffset + d] += weight * values[vOffset + d];
      }
    }
  }

  return output;
}
|
|
|
|
/**
 * Produces random query/key/value buffers for the attention demos.
 *
 * Each buffer is a flat Float32Array of `seqLen * dim` entries, every entry
 * being `Math.random() - 0.5`.
 *
 * @param {number} seqLen Number of rows per buffer.
 * @param {number} dim    Number of values per row.
 * @returns {{query: Float32Array, keys: Float32Array, values: Float32Array}}
 */
function generateSampleData(seqLen, dim) {
  const centered = () => Math.random() - 0.5;

  const query = new Float32Array(seqLen * dim);
  const keys = new Float32Array(seqLen * dim);
  const values = new Float32Array(seqLen * dim);

  // Fill the three buffers in lock-step, one index at a time.
  let idx = 0;
  while (idx < seqLen * dim) {
    query[idx] = centered();
    keys[idx] = centered();
    values[idx] = centered();
    idx += 1;
  }

  return { query, keys, values };
}
|
|
|
|
// Run standard attention: time one pass of the naive baseline on fresh random
// data and show the elapsed time plus the size of its attention matrix.
document.getElementById('runStandard').addEventListener('click', async () => {
  const sample = generateSampleData(20, 256);

  const startedAt = performance.now();
  standardAttention(sample.query, sample.keys, sample.values);
  const elapsedMs = performance.now() - startedAt;

  // 20 x 20 float32 attention matrix, expressed in KB
  const matrixKB = (20 * 20 * 4) / 1024;

  const timeCell = document.getElementById('standardTime');
  timeCell.textContent = `${elapsedMs.toFixed(2)}ms`;

  const memoryCell = document.getElementById('standardMemory');
  memoryCell.textContent = `${matrixKB.toFixed(2)} KB`;
});
|
|
|
|
// Run flash attention: time one awaited call to attention.flashAttention on
// fresh random data and show the elapsed time plus a linear memory estimate.
document.getElementById('runFlash').addEventListener('click', async () => {
  const sample = generateSampleData(20, 256);

  const startedAt = performance.now();
  await attention.flashAttention(sample.query, sample.keys, sample.values);
  const elapsedMs = performance.now() - startedAt;

  // 20 rows x 256 float32 values, expressed in KB (linear in sequence length)
  const linearKB = (20 * 256 * 4) / 1024;

  document.getElementById('flashTime').textContent = `${elapsedMs.toFixed(2)}ms`;

  const memoryCell = document.getElementById('flashMemory');
  memoryCell.textContent = `${linearKB.toFixed(2)} KB`;
  memoryCell.className = 'metric-value good';
});
|
|
|
|
// Compare both: trigger the standard run, wait 100 ms, then trigger the flash
// run by clicking the two existing buttons programmatically.
document.getElementById('compareAttention').addEventListener('click', async () => {
  const pause = (ms) => new Promise((done) => setTimeout(done, ms));

  document.getElementById('runStandard').click();
  await pause(100);
  document.getElementById('runFlash').click();
});
|
|
|
|
// Benchmark scaling: time both implementations over several sequence lengths
// and plot the two timing series on the 'attentionChart' canvas.
document.getElementById('benchmarkScaling').addEventListener('click', async () => {
  const canvas = document.getElementById('attentionChart');
  const ctx = canvas.getContext('2d');

  // Size the backing store for the display's pixel density, then draw in CSS
  // pixels. Fall back to 1 when devicePixelRatio is unavailable.
  const pixelRatio = window.devicePixelRatio || 1;
  canvas.width = canvas.clientWidth * pixelRatio;
  canvas.height = canvas.clientHeight * pixelRatio;
  ctx.scale(pixelRatio, pixelRatio);

  const sequenceLengths = [5, 10, 20, 30, 40];
  const standardTimes = [];
  const flashTimes = [];

  for (const seqLen of sequenceLengths) {
    const { query, keys, values } = generateSampleData(seqLen, 256);

    // Standard
    const start1 = performance.now();
    standardAttention(query, keys, values);
    standardTimes.push(performance.now() - start1);

    // Flash
    const start2 = performance.now();
    await attention.flashAttention(query, keys, values);
    flashTimes.push(performance.now() - start2);
  }

  // Draw chart
  const width = canvas.clientWidth;
  const height = canvas.clientHeight;
  const padding = 40;
  const chartWidth = width - 2 * padding;
  const chartHeight = height - 2 * padding;

  ctx.clearRect(0, 0, width, height);

  // Background
  ctx.fillStyle = 'rgba(0, 0, 0, 0.3)';
  ctx.fillRect(0, 0, width, height);

  // Axes
  ctx.strokeStyle = 'rgba(255, 255, 255, 0.3)';
  ctx.lineWidth = 2;
  ctx.beginPath();
  ctx.moveTo(padding, padding);
  ctx.lineTo(padding, height - padding);
  ctx.lineTo(width - padding, height - padding);
  ctx.stroke();

  // Plot data.
  // Coarse timers can report 0 for every run; without the fallback the
  // time / maxTime ratio would be NaN and neither line would be drawn.
  const maxTime = Math.max(...standardTimes, ...flashTimes) || 1;
  const xStep = chartWidth / (sequenceLengths.length - 1);

  // Draws one timing series as a polyline scaled against maxTime.
  const plotSeries = (times, color) => {
    ctx.strokeStyle = color;
    ctx.lineWidth = 3;
    ctx.beginPath();
    times.forEach((time, i) => {
      const x = padding + i * xStep;
      const y = height - padding - (time / maxTime) * chartHeight;
      if (i === 0) {
        ctx.moveTo(x, y);
      } else {
        ctx.lineTo(x, y);
      }
    });
    ctx.stroke();
  };

  plotSeries(standardTimes, '#f44336'); // Standard attention (red)
  plotSeries(flashTimes, '#4caf50'); // Flash attention (green)

  // Labels
  ctx.fillStyle = 'white';
  ctx.font = '12px sans-serif';
  ctx.fillText('Sequence Length →', width / 2 - 50, height - 10);
  ctx.save();
  ctx.translate(10, height / 2);
  ctx.rotate(-Math.PI / 2);
  ctx.fillText('Time (ms) →', 0, 0);
  ctx.restore();

  // Legend
  ctx.fillStyle = '#f44336';
  ctx.fillRect(width - 200, 20, 20, 20);
  ctx.fillStyle = 'white';
  ctx.fillText('Standard O(N²)', width - 175, 35);

  ctx.fillStyle = '#4caf50';
  ctx.fillRect(width - 200, 50, 20, 20);
  ctx.fillStyle = 'white';
  ctx.fillText('Flash O(N)', width - 175, 65);
});
|
|
|
|
// Generate memories: build 5 clusters of 10 vectors each (256-dim), where every
// vector is its cluster's random base plus a small random offset, then refresh
// the "before" metrics and reset the progress bar.
document.getElementById('generateMemories').addEventListener('click', () => {
  const numMemories = 50;
  const dim = 256;
  const numClusters = 5;
  const memoriesPerCluster = Math.floor(numMemories / numClusters);
  const centeredNoise = () => Math.random() - 0.5;

  currentMemories = [];

  // Generate clustered memories
  for (let c = 0; c < numClusters; c += 1) {
    const base = Float32Array.from({ length: dim }, centeredNoise);

    for (let m = 0; m < memoriesPerCluster; m += 1) {
      const memory = base.map((center) => center + centeredNoise() * 0.2);
      currentMemories.push(memory);
    }
  }

  const sizeKB = (numMemories * dim * 4) / 1024;
  document.getElementById('beforeCount').textContent = numMemories;
  document.getElementById('beforeSize').textContent = `${sizeKB.toFixed(2)} KB`;
  document.getElementById('beforeRedundancy').textContent = 'High';

  const progress = document.getElementById('consolidationProgress');
  progress.style.width = '0%';
  progress.textContent = '0%';
});
|
|
|
|
// Consolidate: run attention.consolidateMemories over the generated memories,
// then report cluster count, size and compression ratio and draw the chart.
document.getElementById('consolidate').addEventListener('click', async () => {
  if (!currentMemories) {
    alert('Please generate memories first');
    return;
  }

  // Snapshot the array: the user may regenerate memories while we await.
  const memories = currentMemories;

  const progressBar = document.getElementById('consolidationProgress');
  const setProgress = (percent) => {
    progressBar.style.width = `${percent}%`;
    progressBar.textContent = `${percent}%`;
  };

  setProgress(50);

  let consolidated;
  try {
    consolidated = await attention.consolidateMemories(memories, {
      threshold: 0.85,
      maxClusters: 10
    });
  } catch (err) {
    // Do not leave the bar stuck at 50% when consolidation fails.
    setProgress(0);
    console.error('Memory consolidation failed:', err);
    return;
  }

  setProgress(100);

  // 256 float32 values per vector, matching the dimension used by the
  // 'generateMemories' handler. NOTE(review): assumes consolidated entries
  // have the same width; confirm against consolidateMemories.
  const bytesPerVector = 256 * 4;
  const newSize = (consolidated.length * bytesPerVector) / 1024;

  // An empty result would otherwise display "Infinityx".
  const compressionText = consolidated.length > 0
    ? `${(memories.length / consolidated.length).toFixed(2)}x`
    : '-';

  document.getElementById('afterCount').textContent = consolidated.length;
  document.getElementById('afterSize').textContent = `${newSize.toFixed(2)} KB`;
  document.getElementById('compression').textContent = compressionText;

  // Draw consolidation chart
  drawConsolidationChart(consolidated);
});
|
|
|
|
/**
 * Draws a bar chart of cluster sizes on the 'consolidationChart' canvas.
 *
 * Each entry of `consolidated` contributes one bar whose height is its
 * `count` relative to the largest `count`; the count is printed above the bar.
 *
 * Bars are now laid out inside a 20px margin on both sides. Previously the
 * bar width was derived from the full canvas width while the bars started at
 * x = 20, so the last bar extended 10px past the right edge.
 *
 * @param {Array<{count: number}>} consolidated Clusters to chart; may be empty,
 *        in which case only the background and axis label are drawn.
 */
function drawConsolidationChart(consolidated) {
  const canvas = document.getElementById('consolidationChart');
  const ctx = canvas.getContext('2d');

  // Size the backing store for the display's pixel density, then draw in CSS pixels.
  const pixelRatio = window.devicePixelRatio || 1;
  canvas.width = canvas.clientWidth * pixelRatio;
  canvas.height = canvas.clientHeight * pixelRatio;
  ctx.scale(pixelRatio, pixelRatio);

  const width = canvas.clientWidth;
  const height = canvas.clientHeight;

  ctx.clearRect(0, 0, width, height);
  ctx.fillStyle = 'rgba(0, 0, 0, 0.3)';
  ctx.fillRect(0, 0, width, height);

  // Bar chart of cluster sizes
  if (consolidated.length > 0) {
    const sideMargin = 20;
    const barGap = 10;
    const baseline = height - 30;
    const maxBarHeight = height - 60;

    // Fall back to 1 so an all-zero chart yields zero-height bars, not NaN.
    const maxCount = Math.max(...consolidated.map(c => c.count)) || 1;

    // Each cluster gets an equal horizontal slot between the two margins.
    const slotWidth = (width - 2 * sideMargin) / consolidated.length;
    const barWidth = Math.max(slotWidth - barGap, 1);

    consolidated.forEach((cluster, i) => {
      const barHeight = (cluster.count / maxCount) * maxBarHeight;
      const x = sideMargin + i * slotWidth;
      const y = baseline - barHeight;

      const gradient = ctx.createLinearGradient(x, y, x, baseline);
      gradient.addColorStop(0, '#667eea');
      gradient.addColorStop(1, '#764ba2');

      ctx.fillStyle = gradient;
      ctx.fillRect(x, y, barWidth, barHeight);

      ctx.fillStyle = 'white';
      ctx.font = '12px sans-serif';
      ctx.fillText(cluster.count, x + barWidth / 2 - 5, y - 5);
    });
  }

  // Label
  ctx.fillStyle = 'white';
  ctx.font = '14px sans-serif';
  ctx.fillText('Members per Cluster', width / 2 - 60, height - 5);
}
|
|
|
|
// Initialize on load. The returned promise was previously dropped, so a failed
// engine setup surfaced only as an unhandled rejection; log it explicitly.
initialize().catch((err) => {
  console.error('Failed to initialize attention engine:', err);
});
|
|
</script>
|
|
</body>
|
|
</html>
|