Ultrafast On-Chip Online Learning via Spline Locality in Kolmogorov–Arnold Networks

Authors: Duc Hoang*, Aarush Gupta*, Philip Harris. MIT. *Equal Contributions.

PDF (arXiv)

KAN Teaser

Overview

This repository contains the reference implementation for the paper "Ultrafast On-Chip Online Learning via Spline Locality in Kolmogorov–Arnold Networks."

The code framework is named after the French pastry ÉCLAIR, short for:

Efficient Continual Learning and Adaptive Inference in Real-time

Accordingly, while the method builds on the Kolmogorov–Arnold Network (KAN) architecture, the main model class in the code is called Eclair (imported from the eclair module).


Installation

Clone the Repository

git clone --recursive git@github.com:Duchstf/ECLAIR.git
cd ECLAIR

If you already cloned without --recursive, initialize the submodule with:

git submodule update --init --recursive

This fetches the HLS_arbitrary_Precision_Types library required for fixed-point arithmetic.

Conda Environment

# Create the environment
conda env create -f environment.yml

# Activate the environment
conda activate eclair

# Update environment (if needed)
conda env update -f environment.yml --prune

Additional Requirements

  • C++ Compiler: g++ with C++11 support (for CPU simulation)
  • Vitis HLS: Required for FPGA synthesis (optional, for hardware builds only)

Quick Start

import sys
sys.path.append('src')  # or add to PYTHONPATH
from eclair import Eclair
import numpy as np

# Define model configuration
config = {
    'layer_sizes': [2, 4, 1],           # Network architecture
    'model_precision': (8, 4),          # Fixed-point: (total_bits, int_bits)
    'input_precision': (8, 4),
    'output_precision': (8, 4),
    'grid_range': [-1, 1],              # Spline grid bounds
    'grid_size': 10,                    # Number of grid intervals
    'spline_order': 3,                  # B-spline order (cubic)
    'lut_bits': 4,                      # LUT resolution = 2^lut_bits
    'model_name': 'my_model',           # Output directory name
    'learning_rate': 0.1,
    'fpga_part': 'xcvu13p-flga2577-2-e',
    'clock_period': '5',
    'params_type': 'ram_2p',
    'params_impl': 'lutram',
    'context_type': 'ram_1p',
    'context_impl': 'lutram',
}

# Initialize and compile model
model = Eclair(config)
model.compile()

# Online learning loop
for x, y_true in data_stream:
    # Forward pass (inference only)
    pred = model.call(x, feedback=0, zero_grad=1)
    
    # Compute gradient feedback (e.g., MSE derivative)
    feedback = 2 * (pred - y_true)
    
    # Backward pass (weight update)
    model.call(x, feedback, zero_grad=0)
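In the loop above, data_stream is any iterable of (input, target) pairs. A minimal synthetic stream for a slowly drifting 1D target, in the spirit of the function_tracking demo, might look like the sketch below (make_data_stream and its parameters are illustrative, not part of the ECLAIR API):

```python
import numpy as np

def make_data_stream(n_samples=1000, drift=0.001, seed=0):
    """Yield (x, y) pairs from a sine target whose phase drifts over time.

    Illustrative only: any iterable of (input, target) pairs works with
    the online loop above. Input dimension matches layer_sizes[0] = 2.
    """
    rng = np.random.default_rng(seed)
    phase = 0.0
    for _ in range(n_samples):
        x = rng.uniform(-1.0, 1.0, size=2)   # stay inside grid_range [-1, 1]
        y = np.sin(np.pi * x[0] + phase) * x[1]
        phase += drift                        # slow concept drift
        yield x, np.array([y])

data_stream = make_data_stream()
x0, y0 = next(data_stream)
```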

API Reference

Eclair Class

The main KAN-based model with spline locality for efficient on-chip learning.

Constructor

Eclair(config, random_seed=None)
| Parameter | Type | Description |
|---|---|---|
| config | dict | Model configuration (see below) |
| random_seed | int | Optional seed for weight initialization |

Configuration Dictionary

| Key | Type | Description |
|---|---|---|
| layer_sizes | list[int] | Network dimensions, e.g., [input, h1, ..., output] |
| model_precision | tuple(int, int) or 'float' | Weight precision: (total_bits, int_bits) |
| input_precision | tuple(int, int) or 'float' | Input precision |
| output_precision | tuple(int, int) or 'float' | Output precision |
| grid_range | list[float] | Spline grid bounds: [min, max] |
| grid_size | int | Number of uniform grid intervals |
| spline_order | int | B-spline order (1 = linear, 3 = cubic) |
| lut_bits | int | Basis-function LUT resolution (2^lut_bits entries) |
| model_name | str | Directory name for generated firmware |
| learning_rate | float | SGD learning rate |
| fpga_part | str | Target FPGA part number |
| clock_period | str | Target clock period in nanoseconds |
| params_type | str | HLS storage type for weights ('ram_2p') |
| params_impl | str | HLS implementation ('lutram', 'bram') |
| context_type | str | HLS storage type for context |
| context_impl | str | HLS implementation for context |
| lut_partition_type | str | Optional: 'complete' or 'block' |
| lut_partition_factor | int | Required if lut_partition_type='block' |
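Several of these keys jointly determine sizes that matter for on-chip resources. The sketch below derives them in plain Python; the basis-count formula grid_size + spline_order follows the usual KAN convention and is an assumption here, so verify the actual values against the generated defines.h:

```python
def derived_sizes(config):
    """Sketch of sizes implied by an ECLAIR-style config dict.

    Assumed conventions, not the library's API:
    - LUT entries per basis function = 2^lut_bits (per the lut_bits key)
    - uniform grid step = (max - min) / grid_size
    - B-spline basis functions per edge = grid_size + spline_order
      (the standard KAN convention)
    """
    lo, hi = config['grid_range']
    g = config['grid_size']
    k = config['spline_order']
    return {
        'lut_entries': 2 ** config['lut_bits'],
        'grid_step': (hi - lo) / g,
        'basis_per_edge': g + k,
    }

# Values from the Quick Start configuration above
sizes = derived_sizes({
    'grid_range': [-1, 1], 'grid_size': 10,
    'spline_order': 3, 'lut_bits': 4,
})
```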

Methods

| Method | Description |
|---|---|
| compile() | Compiles generated C++ to a shared library for CPU simulation |
| call(input, feedback, zero_grad) | Forward/backward pass (see below) |
| build() | Runs Vitis HLS to synthesize FPGA IP |

call(input, feedback, zero_grad)

| Parameter | Type | Description |
|---|---|---|
| input | np.ndarray or list | Input vector |
| feedback | np.ndarray, list, or 0 | Gradient feedback for weight update |
| zero_grad | int | 1 = inference only, 0 = apply weight update |

Returns: Model output (scalar if output_dim=1, else np.ndarray)


MLP Class

Baseline MLP with online backpropagation for comparison.

from mlp import MLP

config = {
    'layer_sizes': [2, 16, 16, 1],
    'model_precision': 'ap_fixed<8, 4, AP_RND_CONV, AP_SAT>',  # Full HLS type string
    'input_precision': 'ap_fixed<8, 4, AP_RND_CONV, AP_SAT>',
    'output_precision': 'ap_fixed<8, 4, AP_RND_CONV, AP_SAT>',
    'model_name': 'mlp_model',
    'learning_rate': 0.1,
    'fpga_part': 'xcvu13p-flga2577-2-e',
    'clock_period': '5',
}

model = MLP(config)
model.compile()

The MLP class has the same compile(), call(), and build() methods as Eclair.


Demo Applications

The demos/ directory contains complete examples:

| Demo | Description |
|---|---|
| function_tracking/ | 1D online regression with drifting target function |
| acrobot/ | Reinforcement learning with TD(n) actor-critic |
| cartpole/ | Online policy learning for CartPole control |
| qubit/ | Superconducting qubit state classification with drift |

Running a Demo

cd demos/function_tracking
python eclair_train.py

Precision Configuration

Fixed-Point Format

For ECLAIR, use the tuple format (total_bits, integer_bits):

'model_precision': (8, 4)  # 8-bit fixed-point with 4 integer bits (4 fractional)

Internally this is compiled to 'ap_fixed<W, I, AP_RND_CONV, AP_SAT>', where W is total_bits and I is integer_bits.
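The numeric behavior of a (total_bits, int_bits) pair can be sketched in plain Python. This is an illustration only: the helper names below are hypothetical, and the Vitis HLS AP_RND_CONV/AP_SAT modes have subtleties this round-to-nearest-even/saturate sketch only approximates:

```python
def fixed_point_info(total_bits, int_bits):
    """Resolution and signed range of an ap_fixed<W, I>-style format.

    Sketch assuming the usual convention that int_bits includes the
    sign bit, so the fractional width is total_bits - int_bits.
    """
    frac_bits = total_bits - int_bits
    step = 2.0 ** -frac_bits
    lo = -2.0 ** (int_bits - 1)           # most negative representable value
    hi = 2.0 ** (int_bits - 1) - step     # most positive representable value
    return step, lo, hi

def quantize(value, total_bits, int_bits):
    """Round to the nearest representable value (Python's round is
    half-to-even, like AP_RND_CONV), saturating at the range edges
    (approximating AP_SAT)."""
    step, lo, hi = fixed_point_info(total_bits, int_bits)
    return min(max(round(value / step) * step, lo), hi)

step, lo, hi = fixed_point_info(8, 4)   # (8, 4): step 0.0625, range [-8, 7.9375]
```

So an (8, 4) model stores weights on a 0.0625 grid over roughly [-8, 8), which is why the Quick Start keeps inputs inside grid_range = [-1, 1].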

For MLP, use full HLS type string:

'model_precision': 'ap_fixed<8, 4, AP_RND_CONV, AP_SAT>'

Floating-Point

'model_precision': 'float'

FPGA Synthesis

To generate synthesizable HLS IP:

model = Eclair(config)
model.compile()  # CPU simulation
model.build()    # Vitis HLS synthesis (requires vitis_hls in PATH)

Generated firmware is placed in {model_name}/firmware/:

my_model/
└── firmware/
    ├── eclair.cpp      # Top-level HLS function
    ├── eclair.h
    ├── defines.h       # Architecture parameters
    ├── parameters.h    # Weights & LUTs
    ├── components.h    # Layer implementations
    ├── bridge.cpp      # CPU test interface
    └── build.tcl       # Vitis HLS build script

Citation

If you use this code, please cite:

@misc{hoang2026ultrafastonchiponlinelearning,
      title={Ultrafast On-chip Online Learning via Spline Locality in Kolmogorov-Arnold Networks}, 
      author={Duc Hoang and Aarush Gupta and Philip Harris},
      year={2026},
      eprint={2602.02056},
      archivePrefix={arXiv},
      primaryClass={cs.AR},
      url={https://arxiv.org/abs/2602.02056}, 
}
