12. INT8 GEMM
General matrix multiply in INT8 requires careful handling of accumulation and output scaling to maintain accuracy across diverse weight distributions.
Quantized GEMM Mathematics
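The quantized operation computes: Y = ((QA * scale_A) @ (QB * scale_B)) / scale_Y
Mapping to integer arithmetic: Y_int = round((QA @ QB) * (scale_A * scale_B / scale_Y))
With per-tensor scales, the scaling factor S = scale_A * scale_B / scale_Y depends only on the quantization parameters, so it can be computed once ahead of time. The product QA @ QB is accumulated in INT32 and multiplied by S afterwards. For small dynamic ranges, the scaling can instead be absorbed into a precomputed lookup table.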
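Following the definition Y = (QA * scale_A) @ (QB * scale_B) / scale_Y, the host-side C++ sketch below requantizes a single INT32 accumulator to INT8. The helper name requantize, the round-to-nearest rule, and the saturation to [-128, 127] are assumptions made for this illustration, not something the chapter prescribes. It also assumes per-tensor (scalar) scales.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical helper: maps one INT32 accumulator (an element of QA @ QB)
// to an INT8 output using S = scale_a * scale_b / scale_y.
// Assumes per-tensor scales and a strictly positive scale_y.
inline int8_t requantize(int32_t acc, float scale_a, float scale_b,
                         float scale_y) {
  assert(scale_y > 0.0f);
  const float s = scale_a * scale_b / scale_y;  // precomputable constant
  // Round to nearest (ties away from zero), then saturate to the INT8 range.
  const long rounded = std::lround(static_cast<float>(acc) * s);
  const long clamped = std::min(127L, std::max(-128L, rounded));
  return static_cast<int8_t>(clamped);
}
```

For example, with scale_a = 0.5, scale_b = 0.5, and scale_y = 0.25 the factor S is exactly 1, so an accumulator of 100 maps to 100 while an accumulator of 1000 saturates at 127.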
Fast Integer GEMM with FMA
Modern GPUs provide an INT8 dot-product-accumulate instruction, dp4a, which CUDA exposes as the __dp4a intrinsic (compute capability 6.1 and later):
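#include <cstdint>

// Assumes A is row-major (M x K) and B is stored column-major, so the K
// values of each output column are contiguous. K must be a multiple of 4
// and both pointers must be 4-byte aligned.
__global__ void int8_gemm_kernel(
    const int8_t* A, const int8_t* B, int32_t* C,
    int M, int N, int K) {
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= M || col >= N) return;  // guard partial tiles
  int32_t sum = 0;
  #pragma unroll 4
  for (int k = 0; k < K; k += 4) {
    // Each 32-bit load packs four consecutive int8 values.
    int a_pack = *reinterpret_cast<const int*>(&A[row * K + k]);
    int b_pack = *reinterpret_cast<const int*>(&B[col * K + k]);
    // The third argument is the running accumulator.
    sum = __dp4a(a_pack, b_pack, sum);
  }
  C[row * N + col] = sum;
}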
__dp4a(a, b, c) treats a and b as four packed 8-bit integers each, multiplies them lane by lane, and adds the four products to the 32-bit accumulator c in a single instruction.
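To make the packed arithmetic concrete, here is a host-side C++ sketch that emulates the signed form of __dp4a. The names pack4 and dp4a_ref are hypothetical helpers introduced for this illustration. The byte order (lowest byte first) is an assumption that matches loading four consecutive int8 values on a little-endian machine.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: packs four int8 values into one 32-bit word,
// lowest byte first.
inline uint32_t pack4(int8_t x0, int8_t x1, int8_t x2, int8_t x3) {
  return static_cast<uint32_t>(static_cast<uint8_t>(x0)) |
         (static_cast<uint32_t>(static_cast<uint8_t>(x1)) << 8) |
         (static_cast<uint32_t>(static_cast<uint8_t>(x2)) << 16) |
         (static_cast<uint32_t>(static_cast<uint8_t>(x3)) << 24);
}

// Host-side emulation of signed dp4a: returns c plus the sum of the four
// lane-wise products a_i * b_i.
inline int32_t dp4a_ref(uint32_t a, uint32_t b, int32_t c) {
  for (int lane = 0; lane < 4; ++lane) {
    // Extract one byte and reinterpret it as a signed 8-bit value
    // (two's complement wraparound, as on all mainstream compilers).
    const int8_t ai = static_cast<int8_t>((a >> (8 * lane)) & 0xFFu);
    const int8_t bi = static_cast<int8_t>((b >> (8 * lane)) & 0xFFu);
    c += static_cast<int32_t>(ai) * static_cast<int32_t>(bi);
  }
  return c;
}
```

A reference like this can be used to check a GPU kernel on small inputs: dp4a_ref(pack4(1, 2, 3, 4), pack4(5, 6, 7, 8), 10) evaluates 1*5 + 2*6 + 3*7 + 4*8 + 10 = 80.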
Output Dequantization
Post-computation, apply scaling:
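// scale_A holds one scale per output row; scale_B[0] is a single
// per-tensor scale for B.
__global__ void dequantize_kernel(
    const int32_t* C_int, float* C_float,
    int M, int N, const float* scale_A,
    const float* scale_B, float scale_Y) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int total = M * N;
  if (idx < total) {
    int row = idx / N;
    // Y = acc * scale_A * scale_B / scale_Y. Pass scale_Y = 1.0f to obtain
    // the real-valued (fully dequantized) result.
    float scale = scale_A[row] * scale_B[0] / scale_Y;
    C_float[idx] = static_cast<float>(C_int[idx]) * scale;
  }
}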
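A CPU reference can help validate the scaling step on small matrices. The sketch below follows the definition Y = (QA * scale_A) @ (QB * scale_B) / scale_Y, so each accumulator is multiplied by scale_A * scale_B / scale_Y. The name dequantize_ref and the use of std::vector are assumptions for this illustration. As in the kernel signature, scale_a is per row and the B scale is a single value.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference: c_int is row-major M x N, scale_a has one
// entry per row, scale_b and scale_y are scalars.
inline std::vector<float> dequantize_ref(const std::vector<int32_t>& c_int,
                                         int M, int N,
                                         const std::vector<float>& scale_a,
                                         float scale_b, float scale_y) {
  assert(static_cast<int>(c_int.size()) == M * N);
  assert(static_cast<int>(scale_a.size()) == M);
  std::vector<float> out(c_int.size());
  for (int row = 0; row < M; ++row) {
    // One combined scale per row, applied to every column in that row.
    const float scale = scale_a[row] * scale_b / scale_y;
    for (int col = 0; col < N; ++col) {
      const int idx = row * N + col;
      out[idx] = static_cast<float>(c_int[idx]) * scale;
    }
  }
  return out;
}
```

Comparing this output element by element with the GPU result, within a small floating-point tolerance, is one way to catch an inverted or misplaced scale.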
Local verification checkpoint
Run the smallest example from this chapter in a local workspace and record the package version, runtime, data path, and observed output. If the result depends on model size, vector count, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.
Implement INT8 GEMM with asymmetric quantization (zero-point) support. Handle the zero-point term correctly in accumulation.