Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions 03_nf4_dequant/flashzxi/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
cmake_minimum_required(VERSION 3.18)

project(nf4_dequant LANGUAGES CXX CUDA)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

set(CMAKE_CUDA_ARCHITECTURES native)

if(NOT CMAKE_CONFIGURATION_TYPES AND NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING "Build type" FORCE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS
"Debug" "Release" "RelWithDebInfo" "MinSizeRel")
endif()

add_executable(nf4_dequant
src/main.cu
src/nf4_dequant_naive.cu
src/nf4_dequant_warp8.cu
)

target_include_directories(nf4_dequant PRIVATE
${CMAKE_SOURCE_DIR}/include
)

# 单 TU/简单工程:关闭 RDC 更利于调试
set_target_properties(nf4_dequant PROPERTIES
CUDA_SEPARABLE_COMPILATION OFF
)

target_compile_options(nf4_dequant PRIVATE
$<$<AND:$<CONFIG:Debug>,$<COMPILE_LANGUAGE:CXX>>:-g -O0>
$<$<AND:$<CONFIG:Debug>,$<COMPILE_LANGUAGE:CUDA>>:-G -g -O0>

$<$<AND:$<CONFIG:Release>,$<COMPILE_LANGUAGE:CXX>>:-O3>
$<$<AND:$<CONFIG:Release>,$<COMPILE_LANGUAGE:CUDA>>:-O3>

$<$<AND:$<CONFIG:RelWithDebInfo>,$<COMPILE_LANGUAGE:CXX>>:-g -O2>
$<$<AND:$<CONFIG:RelWithDebInfo>,$<COMPILE_LANGUAGE:CUDA>>:-lineinfo -g -O2>
)
Empty file.
16 changes: 16 additions & 0 deletions 03_nf4_dequant/flashzxi/Report.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
## NF4 反量化
author: flashzxi

本项目是利用cuda高效计算nf4反量化,对比bitsandbytes 实现

本项目的假设:
每个block大小为64个元素

二级量化每个group包含256个block.

## 实现
总共实现了三个版本,一个最简单的naive版本,一个二级反量化和一级反量化分开计算的版本以及最终的单独kernel解两层反量化的版本。其中naive版本在`src/nf4_dequant_naive.cu`,其余两个版本都在`src/nf4_dequant_warp8.cu`



开发工程中,我尝试
127 changes: 127 additions & 0 deletions 03_nf4_dequant/flashzxi/include/common.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
//
// Created by core_dump on 2026/2/25.
//

#pragma once

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <chrono>
#include <iostream>

__host__ __device__ __forceinline__
float mix_mul(float fp, __half h) {
return fp * __half2float(h);
}

__host__ __device__ __forceinline__
float mix_mul(float fp, __nv_bfloat16 h) {
return fp * __bfloat162float(h);
}

__host__ __device__ __forceinline__
float f162float(__half h) {
return __half2float(h);
}

__host__ __device__ __forceinline__
float f162float(__nv_bfloat16 h) {
return __bfloat162float(h);
}


#define CUDA_CHECK(call) \
{ \
cudaError_t err = call; \
if (err != cudaSuccess) { \
std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ \
<< " - " << cudaGetErrorString(err) << "\n"; \
std::exit(-1); \
} \
}

class Timer {
public:
using clock = std::chrono::high_resolution_clock;

Timer() : running_(false), elapsed_ms_(0.0) {}

void tic() {
start_ = clock::now();
running_ = true;
}

double toc() {
if (!running_) {
return elapsed_ms_;
}
auto end = clock::now();
elapsed_ms_ = std::chrono::duration<double, std::milli>(end - start_).count();
running_ = false;
return elapsed_ms_;
}

double elapsed() const {
if (!running_) {
return elapsed_ms_;
}
auto now = clock::now();
return std::chrono::duration<double, std::milli>(now - start_).count();
}

void reset() {
running_ = false;
elapsed_ms_ = 0.0;
}

private:
clock::time_point start_;
bool running_;
double elapsed_ms_;
};

class Tracer {
public:
Tracer() {}

void start() {
timer_.reset();
timer_.tic();
}

void stop() {
total_elapsed_ms_ += timer_.toc();
}

Tracer& memcpy_accumulate(uint64_t cpy_size_in_byte) {
total_data_cpy_in_bytes_ += cpy_size_in_byte;
return *this;
}

double bandwidth_bytes_per_s() const {
if (total_elapsed_ms_ <= 0.0) {
return 0.0;
}
return static_cast<double>(total_data_cpy_in_bytes_) * 1000.0 / total_elapsed_ms_;
}

double bandwidth_gib_per_s() const {
if (total_elapsed_ms_ <= 0.0) {
return 0.0;
}
constexpr double kBytesPerGiB = 1024.0 * 1024.0 * 1024.0;
return static_cast<double>(total_data_cpy_in_bytes_) * 1000.0 / total_elapsed_ms_ / kBytesPerGiB;
}

void print(std::ostream& os = std::cout) const {
os << "elapsed: " << total_elapsed_ms_ << " ms, "
<< "effective bandwidth: " << bandwidth_gib_per_s() << " GiB/s\n";
}

private:
Timer timer_;

uint64_t total_data_cpy_in_bytes_ = 0;
double total_elapsed_ms_;
};
10 changes: 10 additions & 0 deletions 03_nf4_dequant/flashzxi/include/nf4_dequant.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
//
// Created by core_dump on 2/25/26.
//

#pragma once

#include "quant_state.h"
void nf4_dequant_naive(const QuantState& quant_state, __half* output);
void nf4_dequant_warp8_batch32_two_phase(const QuantState& quant_state, __half* output);
void nf4_dequant_warp8_batch8_one_phase(const QuantState& quant_state, __half* output);
Loading