|
| 1 | +'use client'; |
| 2 | + |
| 3 | +import { useRef, useEffect } from 'react'; |
| 4 | +import { evalChartData, unseenChartData, evalChartColors, frankaModelOrder } from '../data/evalChartData'; |
| 5 | +import pattern from 'patternomaly'; |
| 6 | + |
| 7 | +const DroidBarChart = () => { |
| 8 | + const chartRef = useRef(null); |
| 9 | + const chartInstance = useRef(null); |
| 10 | + |
| 11 | + useEffect(() => { |
| 12 | + if (typeof window === 'undefined' || !chartRef.current) return; |
| 13 | + |
| 14 | + let chartObj; |
| 15 | + |
| 16 | + const initChart = async () => { |
| 17 | + try { |
| 18 | + const ChartModule = await import('chart.js'); |
| 19 | + const { Chart, CategoryScale, LinearScale, BarController, BarElement, |
| 20 | + Title, Tooltip, Legend } = ChartModule; |
| 21 | + |
| 22 | + // Custom error bars plugin |
| 23 | + const errorBarsPlugin = { |
| 24 | + id: 'errorBars', |
| 25 | + afterDatasetsDraw(chart) { |
| 26 | + const ctx = chart.ctx; |
| 27 | + chart.data.datasets.forEach((dataset, datasetIndex) => { |
| 28 | + const meta = chart.getDatasetMeta(datasetIndex); |
| 29 | + if (!meta.hidden && dataset.errorBars) { |
| 30 | + meta.data.forEach((bar, index) => { |
| 31 | + const errorPlus = dataset.errorBars.plus[index]; |
| 32 | + const errorMinus = dataset.errorBars.minus[index]; |
| 33 | + |
| 34 | + if (errorPlus > 0 || errorMinus > 0) { |
| 35 | + const x = bar.x; |
| 36 | + const y = bar.y; |
| 37 | + const barWidth = bar.width; |
| 38 | + |
| 39 | + // Calculate error bar positions |
| 40 | + const yScale = chart.scales.y; |
| 41 | + const errorBarWidth = barWidth * 0.5; |
| 42 | + const capWidth = errorBarWidth; |
| 43 | + |
| 44 | + ctx.save(); |
| 45 | + ctx.strokeStyle = dataset.errorBars.color || '#7A7A7C'; |
| 46 | + ctx.lineWidth = dataset.errorBars.lineWidth || 1; |
| 47 | + |
| 48 | + // Draw upper error bar |
| 49 | + if (errorPlus > 0) { |
| 50 | + const yTop = yScale.getPixelForValue(yScale.getValueForPixel(y) + errorPlus); |
| 51 | + ctx.beginPath(); |
| 52 | + ctx.moveTo(x, y); |
| 53 | + ctx.lineTo(x, yTop); |
| 54 | + ctx.stroke(); |
| 55 | + |
| 56 | + // Draw cap |
| 57 | + ctx.beginPath(); |
| 58 | + ctx.moveTo(x - capWidth / 2, yTop); |
| 59 | + ctx.lineTo(x + capWidth / 2, yTop); |
| 60 | + ctx.stroke(); |
| 61 | + } |
| 62 | + |
| 63 | + // Draw lower error bar |
| 64 | + if (errorMinus > 0) { |
| 65 | + const yBottom = yScale.getPixelForValue(yScale.getValueForPixel(y) - errorMinus); |
| 66 | + ctx.beginPath(); |
| 67 | + ctx.moveTo(x, y); |
| 68 | + ctx.lineTo(x, yBottom); |
| 69 | + ctx.stroke(); |
| 70 | + |
| 71 | + // Draw cap |
| 72 | + ctx.beginPath(); |
| 73 | + ctx.moveTo(x - capWidth / 2, yBottom); |
| 74 | + ctx.lineTo(x + capWidth / 2, yBottom); |
| 75 | + ctx.stroke(); |
| 76 | + } |
| 77 | + |
| 78 | + ctx.restore(); |
| 79 | + } |
| 80 | + }); |
| 81 | + } |
| 82 | + }); |
| 83 | + } |
| 84 | + }; |
| 85 | + |
| 86 | + Chart.register( |
| 87 | + CategoryScale, |
| 88 | + LinearScale, |
| 89 | + BarController, |
| 90 | + BarElement, |
| 91 | + Title, |
| 92 | + Tooltip, |
| 93 | + Legend, |
| 94 | + errorBarsPlugin |
| 95 | + ); |
| 96 | + |
| 97 | + if (chartInstance.current) { |
| 98 | + chartInstance.current.destroy(); |
| 99 | + chartInstance.current = null; |
| 100 | + } |
| 101 | + |
| 102 | + const ctx = chartRef.current.getContext('2d'); |
| 103 | + |
| 104 | + // Prepare datasets for Seen vs Unseen |
| 105 | + const datasets = frankaModelOrder.map((modelKey) => { |
| 106 | + const seenData = evalChartData.franka.models[modelKey]; |
| 107 | + const unseenData = unseenChartData.droidFranka.models[modelKey]; |
| 108 | + const color = evalChartColors[modelKey]; |
| 109 | + const isPretrained = modelKey.includes('(Pretrained)'); |
| 110 | + |
| 111 | + // Use success rate (second element) for both seen and unseen |
| 112 | + const seenSuccessRate = seenData.mean[1]; |
| 113 | + const seenStderr = seenData.stderr[1]; |
| 114 | + |
| 115 | + const unseenSuccessRate = unseenData.mean[1]; |
| 116 | + const unseenStderr = unseenData.stderr[1]; |
| 117 | + |
| 118 | + // Use diagonal stripes for pretrained models |
| 119 | + const backgroundColor = isPretrained |
| 120 | + ? pattern.draw('diagonal', color, 'white', 6) |
| 121 | + : color; |
| 122 | + |
| 123 | + return { |
| 124 | + label: modelKey.replace('pi0.5', 'π₀.₅'), |
| 125 | + data: [seenSuccessRate, unseenSuccessRate], |
| 126 | + backgroundColor: backgroundColor, |
| 127 | + borderColor: 'white', |
| 128 | + borderWidth: isPretrained ? 0.6 : 0.5, |
| 129 | + borderRadius: 0, |
| 130 | + barThickness: 'flex', |
| 131 | + maxBarThickness: 60, |
| 132 | + errorBars: { |
| 133 | + plus: [seenStderr, unseenStderr], |
| 134 | + minus: [seenStderr, unseenStderr], |
| 135 | + color: '#7A7A7C', |
| 136 | + lineWidth: 1, |
| 137 | + width: '50%' |
| 138 | + }, |
| 139 | + modelKey: modelKey |
| 140 | + }; |
| 141 | + }); |
| 142 | + |
| 143 | + chartObj = new Chart(ctx, { |
| 144 | + type: 'bar', |
| 145 | + data: { |
| 146 | + labels: ['Seen', 'Unseen'], |
| 147 | + datasets: datasets |
| 148 | + }, |
| 149 | + options: { |
| 150 | + responsive: true, |
| 151 | + maintainAspectRatio: false, |
| 152 | + interaction: { |
| 153 | + mode: 'index', |
| 154 | + intersect: false, |
| 155 | + }, |
| 156 | + plugins: { |
| 157 | + legend: { |
| 158 | + display: true, |
| 159 | + position: 'bottom', |
| 160 | + align: 'center', |
| 161 | + labels: { |
| 162 | + boxWidth: 15, |
| 163 | + padding: 12, |
| 164 | + font: { |
| 165 | + family: "'NVIDIA Sans', sans-serif", |
| 166 | + size: 12 |
| 167 | + }, |
| 168 | + color: '#48484A', |
| 169 | + usePointStyle: false |
| 170 | + }, |
| 171 | + onClick: (e, legendItem, legend) => { |
| 172 | + const index = legendItem.datasetIndex; |
| 173 | + const chart = legend.chart; |
| 174 | + const meta = chart.getDatasetMeta(index); |
| 175 | + meta.hidden = meta.hidden === null ? !chart.data.datasets[index].hidden : null; |
| 176 | + chart.update(); |
| 177 | + } |
| 178 | + }, |
| 179 | + tooltip: { |
| 180 | + enabled: true, |
| 181 | + backgroundColor: 'rgba(255, 255, 255, 0.95)', |
| 182 | + titleColor: '#48484A', |
| 183 | + bodyColor: '#48484A', |
| 184 | + titleFont: { |
| 185 | + family: "'NVIDIA Sans', sans-serif", |
| 186 | + size: 13, |
| 187 | + weight: 'bold' |
| 188 | + }, |
| 189 | + bodyFont: { |
| 190 | + family: "'NVIDIA Sans', sans-serif", |
| 191 | + size: 12 |
| 192 | + }, |
| 193 | + borderColor: '#D1D1D6', |
| 194 | + borderWidth: 1, |
| 195 | + displayColors: true, |
| 196 | + padding: 12, |
| 197 | + callbacks: { |
| 198 | + label: function(context) { |
| 199 | + const label = context.dataset.label || ''; |
| 200 | + const value = context.parsed.y; |
| 201 | + return `${label}: ${value.toFixed(1)}%`; |
| 202 | + }, |
| 203 | + labelColor: function(context) { |
| 204 | + return { |
| 205 | + borderColor: context.dataset.backgroundColor, |
| 206 | + backgroundColor: context.dataset.backgroundColor, |
| 207 | + borderWidth: 2, |
| 208 | + borderRadius: 2 |
| 209 | + }; |
| 210 | + } |
| 211 | + } |
| 212 | + }, |
| 213 | + title: { |
| 214 | + display: true, |
| 215 | + text: 'DROID-Franka: Seen vs Unseen Average Success Rate', |
| 216 | + align: 'center', |
| 217 | + font: { |
| 218 | + family: "'NVIDIA Sans', sans-serif", |
| 219 | + size: 16, |
| 220 | + weight: 'bold' |
| 221 | + }, |
| 222 | + color: '#48484A', |
| 223 | + padding: { top: 10, bottom: 20 } |
| 224 | + } |
| 225 | + }, |
| 226 | + scales: { |
| 227 | + y: { |
| 228 | + beginAtZero: true, |
| 229 | + max: 100, |
| 230 | + ticks: { |
| 231 | + stepSize: 20, |
| 232 | + callback: function(value) { |
| 233 | + return value + '%'; |
| 234 | + }, |
| 235 | + font: { |
| 236 | + family: "'NVIDIA Sans', sans-serif", |
| 237 | + size: 14 |
| 238 | + }, |
| 239 | + color: '#636366' |
| 240 | + }, |
| 241 | + grid: { |
| 242 | + color: '#B0B0B5', |
| 243 | + lineWidth: 0.9, |
| 244 | + borderDash: [2, 2] |
| 245 | + }, |
| 246 | + title: { |
| 247 | + display: true, |
| 248 | + text: 'Average Success Rate (%)', |
| 249 | + font: { |
| 250 | + family: "'NVIDIA Sans', sans-serif", |
| 251 | + size: 14, |
| 252 | + weight: 'bold' |
| 253 | + }, |
| 254 | + color: '#636366', |
| 255 | + padding: { top: 0, bottom: 10 } |
| 256 | + } |
| 257 | + }, |
| 258 | + x: { |
| 259 | + ticks: { |
| 260 | + font: { |
| 261 | + family: "'NVIDIA Sans', sans-serif", |
| 262 | + size: 13, |
| 263 | + weight: '600' |
| 264 | + }, |
| 265 | + color: '#48484A' |
| 266 | + }, |
| 267 | + grid: { |
| 268 | + display: false |
| 269 | + } |
| 270 | + } |
| 271 | + }, |
| 272 | + animation: { |
| 273 | + duration: 0 |
| 274 | + }, |
| 275 | + layout: { |
| 276 | + padding: { |
| 277 | + top: 0, |
| 278 | + bottom: 10 |
| 279 | + } |
| 280 | + } |
| 281 | + } |
| 282 | + }); |
| 283 | + |
| 284 | + chartInstance.current = chartObj; |
| 285 | + } catch (error) { |
| 286 | + console.error("Failed to load Chart.js:", error); |
| 287 | + } |
| 288 | + }; |
| 289 | + |
| 290 | + initChart(); |
| 291 | + |
| 292 | + return () => { |
| 293 | + if (chartInstance.current) { |
| 294 | + chartInstance.current.destroy(); |
| 295 | + } |
| 296 | + }; |
| 297 | + }, []); |
| 298 | + |
| 299 | + return ( |
| 300 | + <div style={{ |
| 301 | + width: '100%', |
| 302 | + marginTop: '1.5rem', |
| 303 | + marginBottom: '2rem', |
| 304 | + padding: '1.5rem', |
| 305 | + backgroundColor: 'white', |
| 306 | + borderRadius: '8px', |
| 307 | + position: 'relative' |
| 308 | + }}> |
| 309 | + <div style={{ |
| 310 | + height: '400px', |
| 311 | + position: 'relative' |
| 312 | + }}> |
| 313 | + <canvas ref={chartRef} style={{ touchAction: 'manipulation' }}></canvas> |
| 314 | + </div> |
| 315 | + </div> |
| 316 | + ); |
| 317 | +}; |
| 318 | + |
| 319 | +export default DroidBarChart; |
0 commit comments