13. FP8 Inference

Chapter 13 of 18 · 15 min

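FP8 formats (E4M3 and E5M2, defined in the OCP 8-bit Floating Point (OFP8) specification rather than in IEEE 754) keep an exponent field. They therefore span a far wider dynamic range than INT8 at the cost of very low precision, a trade-off that suits transformer inference.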

FP8 Format Details

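E4M3: 4-bit exponent, 3-bit mantissa, bias=7, max=448, min normal=2^-6, min subnormal=2^-9 (no infinities; only S.1111.111 encodes NaN)
E5M2: 5-bit exponent, 2-bit mantissa, bias=15, max=57344, min normal=2^-14, min subnormal=2^-16 (IEEE-style infinities and NaNs)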
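The limits above follow directly from the bit layouts. The small host-side decoder below is a sketch for checking them; `decode_fp8` is a hypothetical helper, not part of any CUDA header, and it deliberately does not special-case NaN/Inf encodings.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Decode an FP8 bit pattern (1 sign bit, exp_bits exponent bits, man_bits
// mantissa bits) into a double. Exponent field 0 encodes subnormals.
// NaN/Inf encodings are not special-cased, so only pass finite patterns.
double decode_fp8(std::uint8_t bits, int exp_bits, int man_bits, int bias) {
    assert(exp_bits + man_bits == 7);
    int sign = (bits >> 7) & 1;
    int exp_field = (bits >> man_bits) & ((1 << exp_bits) - 1);
    int man_field = bits & ((1 << man_bits) - 1);
    double frac = man_field / double(1 << man_bits);
    double mag = (exp_field == 0)
        ? std::ldexp(frac, 1 - bias)                  // subnormal: 0.f * 2^(1-bias)
        : std::ldexp(1.0 + frac, exp_field - bias);   // normal: 1.f * 2^(e-bias)
    return sign ? -mag : mag;
}
```

For example, decode_fp8(0x7E, 4, 3, 7) gives 448 (pattern 0.1111.110, because 0.1111.111 is NaN in E4M3), decode_fp8(0x01, 4, 3, 7) gives 2^-9, decode_fp8(0x7B, 5, 2, 15) gives 57344, and decode_fp8(0x01, 5, 2, 15) gives 2^-16.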

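For inference, weights and activations typically use E4M3, which keeps the extra mantissa bit; in training, gradients use E5M2 for its larger range. In both cases matrix-multiply accumulation stays in FP32.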

FP8 GEMM Implementation

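#include <cuda_fp8.h>  // __nv_fp8_e4m3 (CUDA 11.8+)

__device__ __nv_fp8_e4m3 float_to_fp8_e4m3(float val) {
    // The float constructor rounds to nearest even and uses __NV_SATFINITE,
    // so values beyond +/-448 clamp to +/-448 instead of becoming NaN.
    return __nv_fp8_e4m3(val);
}

// Naive reference FP8 GEMM: C = quantize(out_scale * dequant(A) @ dequant(B)).
// A (MxK) and B (KxN) are row-major E4M3 tensors quantized as x * scale,
// so dequantization divides by scale_a * scale_b. Products accumulate in
// FP32; production kernels use FP8 tensor cores (e.g. via cuBLASLt) instead.
__global__ void fp8_gemm_kernel(
    const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B,
    __nv_fp8_e4m3* C, int M, int N, int K,
    float scale_a, float scale_b, float out_scale) {

    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) return;  // skip threads outside the matrix

    float acc = 0.0f;
    for (int k = 0; k < K; k++) {
        float a = static_cast<float>(A[row * K + k]);
        float b = static_cast<float>(B[k * N + col]);
        acc += a * b;
    }

    float y = acc / (scale_a * scale_b);                  // back to real units
    C[row * N + col] = float_to_fp8_e4m3(y * out_scale);  // requantize to E4M3
}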
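To validate fp8_gemm_kernel, a CPU reference that mimics the E4M3 conversion is useful. The sketch below (`fake_quant_e4m3` is a hypothetical name) assumes round-to-nearest-even with saturation to ±448, matching the documented __NV_SATFINITE mode; treating infinities as saturating too is an assumption. It returns the dequantized value, so a plain CPU GEMM over fake-quantized inputs approximates what the kernel computes.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Host-side emulation of FP32 -> E4M3 -> FP32 ("fake quantization").
// Assumes round-to-nearest-even and saturation to [-448, 448].
float fake_quant_e4m3(float x) {
    const float kMax = 448.0f;        // largest finite E4M3 value
    const int kMinNormalExp = -6;     // 1 - bias (bias = 7)
    const int kMantissaBits = 3;

    if (std::isnan(x)) return x;
    if (std::isinf(x)) return std::copysign(kMax, x);  // assumed to saturate
    float a = std::fabs(x);
    if (a == 0.0f) return x;

    int e2;
    std::frexp(a, &e2);                                // a = m * 2^e2, m in [0.5, 1)
    int e = std::max(e2 - 1, kMinNormalExp);           // subnormals share the min exponent
    float quantum = std::ldexp(1.0f, e - kMantissaBits);  // E4M3 spacing near a

    // std::nearbyint rounds half to even under the default rounding mode.
    float q = std::nearbyint(a / quantum) * quantum;
    q = std::min(q, kMax);                             // saturate instead of overflowing
    assert(q >= 0.0f && q <= kMax);
    return std::copysign(q, x);
}
```

Because ties go to even, 1.0625 (exactly halfway between 1.0 and 1.125) maps to 1.0, and 464 (halfway between 448 and the next step) maps back to 448.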

FP8 Scaling Considerations

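Scales are derived automatically from each tensor's observed range, either per tensor or per row. The routine below computes a single per-tensor scale that maps the largest magnitude onto the E4M3 maximum; per-row scaling repeats the same computation for each row.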

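#include <cmath>  // std::fabs, std::fmax

// Per-tensor scale: maps the tensor's largest magnitude onto max_bound
// (448 = E4M3 max). Quantize with x * scale, dequantize with x / scale.
void compute_fp8_scales(const float* fp32_tensor, int size,
                        float& scale, float max_bound = 448.0f) {
    float abs_max = 0.0f;
    for (int i = 0; i < size; i++) {
        abs_max = std::fmax(abs_max, std::fabs(fp32_tensor[i]));
    }
    // An all-zero (or empty) tensor has no range to map; use scale 1.
    scale = (abs_max > 0.0f) ? max_bound / abs_max : 1.0f;
}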
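Per-row scaling applies the same idea to each row of a matrix, so an outlier in one row does not crush the resolution of the others. A minimal sketch, assuming a row-major layout; `compute_fp8_row_scales` is a hypothetical helper written for this example:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// One scale per row of a row-major rows x cols matrix.
// Quantize row r with x * scales[r]; dequantize with x / scales[r].
std::vector<float> compute_fp8_row_scales(const float* data, int rows, int cols,
                                          float max_bound = 448.0f) {
    assert(rows >= 0 && cols >= 0);
    std::vector<float> scales(rows, 1.0f);
    for (int r = 0; r < rows; r++) {
        float abs_max = 0.0f;
        for (int c = 0; c < cols; c++) {
            abs_max = std::max(abs_max, std::fabs(data[r * cols + c]));
        }
        // All-zero rows keep scale 1 to avoid dividing by zero.
        if (abs_max > 0.0f) scales[r] = max_bound / abs_max;
    }
    return scales;
}
```

For the 2x3 matrix {1, -2, 0.5; 0, 0, 0}, the first row gets scale 448 / 2 = 224 and the all-zero second row keeps scale 1. The cost is storing one float per row and applying a per-row factor when dequantizing GEMM outputs.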

Local verification checkpoint

Run the smallest example from this chapter locally and record the CUDA toolkit version, GPU model, and observed output. If the result depends on matrix size, GPU architecture (native FP8 tensor cores require compute capability 8.9 or newer), or available memory, note that constraint beside the exercise so the lesson remains reproducible.

EXERCISE

Implement FP8 attention with scaled softmax, ensuring numerical stability for large logit values.
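As a starting point for the softmax stage only, here is a host-side sketch of the standard max-subtraction trick. It assumes the logits have already been dequantized from FP8 to FP32 and that the caller passes the attention factor (typically 1/sqrt(d_head)) as `scale`; `stable_softmax` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable softmax over one row of attention logits, in FP32.
std::vector<float> stable_softmax(const std::vector<float>& logits, float scale) {
    assert(!logits.empty());
    float row_max = -INFINITY;
    for (float x : logits) row_max = std::max(row_max, x * scale);

    // Subtracting the row max makes every exponent <= 0, so exp() cannot
    // overflow even when the raw logits are very large.
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < logits.size(); i++) {
        probs[i] = std::exp(logits[i] * scale - row_max);
        sum += probs[i];
    }
    for (float& p : probs) p /= sum;
    return probs;
}
```

Logits of {1000, 1000} would overflow a naive exp() in FP32, but here they yield {0.5, 0.5}. Keeping this stage in FP32 and quantizing the probabilities afterwards avoids losing small probabilities to E4M3's coarse spacing.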