05. Multi-Label Classification

Chapter 5 of 18 · 15 min

Standard classification assigns exactly one label to each input instance. Multi-label classification permits multiple simultaneous labels, reflecting real-world complexity where categories overlap. A news article might simultaneously belong to "Politics," "Economics," and "Technology" categories.

Binary relevance treats a multi-label problem as a set of independent binary classification tasks. Each label gets its own classifier that decides whether the label is present or absent. The approach is conceptually simple, but it ignores label correlations: a financial news story tagged with both "Markets" and "Banking" likely contains vocabulary patterns specific to that intersection, and independent classifiers cannot exploit them.
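
The decomposition can be sketched in a few lines. The snippet below is a minimal illustration, not a full training pipeline: the four documents, the helper names `to_binary_tasks` and `cooccurrence_counts`, and the label set are all made up for this example. It builds one binary dataset per label and then counts label pairs, which is exactly the information binary relevance discards.

```python
from collections import Counter
from itertools import combinations

# Hypothetical toy dataset: each document comes with a set of labels.
documents = [
    ("Central bank raises rates as lenders tighten credit", {"Markets", "Banking"}),
    ("Stock index closes at record high", {"Markets"}),
    ("Regional bank reports quarterly deposits", {"Banking"}),
    ("Bond yields jump after bank stress tests", {"Markets", "Banking"}),
]
label_names = ["Markets", "Banking"]


def to_binary_tasks(docs, names):
    """Split a multi-label dataset into one binary dataset per label."""
    tasks = {}
    for name in names:
        # Target is 1 when the label is present on the document, 0 otherwise.
        tasks[name] = [(text, int(name in doc_labels)) for text, doc_labels in docs]
    return tasks


def cooccurrence_counts(docs):
    """Count how often each pair of labels appears on the same document."""
    counts = Counter()
    for _, doc_labels in docs:
        for pair in combinations(sorted(doc_labels), 2):
            counts[pair] += 1
    return counts


binary_tasks = to_binary_tasks(documents, label_names)
pair_counts = cooccurrence_counts(documents)

for name, rows in binary_tasks.items():
    print(name, [target for _, target in rows])
print(dict(pair_counts))
```

Each entry in `binary_tasks` could be handed to any binary classifier. The pair counts show that half of the toy documents carry both labels, a signal none of the per-label classifiers sees.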

Local LLMs handle multi-label classification through prompt reformulation: the prompt enumerates the allowed labels and specifies the required output format. When the model produces a confidence score per label, threshold calibration influences results significantly: the same model achieves different precision-recall tradeoffs depending on the confidence cutoff used for positive label assignment.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Placeholder name: point this at a local checkpoint fine-tuned for these labels.
model_name = "local-llama3-for-classification"

labels = ["politics", "economy", "technology", "sports",
          "entertainment", "science", "health", "world",
          "business", "crime"]

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(labels),  # one output logit per label
    problem_type="multi_label_classification"
)
model.eval()  # inference mode (disables dropout)

def classify_multilabel(text, threshold=0.5):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

    with torch.no_grad():
        outputs = model(**inputs)
        # Sigmoid scores each label independently, so the values need not sum to 1.
        probs = torch.sigmoid(outputs.logits)[0]

    # Keep every label whose probability exceeds the cutoff.
    predictions = [label for label, prob in zip(labels, probs.tolist())
                   if prob > threshold]

    # Fall back to a sentinel value when no label clears the threshold.
    return predictions if predictions else ["unknown"]

text = "Federal Reserve announces interest rate decision affecting tech sector investments"
result = classify_multilabel(text, threshold=0.5)
print(result)
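
The example above scores labels with a classification head. The prompt-based route mentioned earlier could look roughly like the sketch below, which only builds the prompt and parses a reply; it does not call any model. The prompt wording, the JSON-array output convention, and the helper names `build_prompt` and `parse_labels` are assumptions chosen for illustration.

```python
import json

LABELS = ["politics", "economy", "technology", "sports", "entertainment",
          "science", "health", "world", "business", "crime"]


def build_prompt(text, labels=LABELS):
    """Enumerate the allowed labels and pin down the output format."""
    label_list = ", ".join(labels)
    return (
        "Assign every label that applies to the text below.\n"
        f"Allowed labels: {label_list}\n"
        'Answer with a JSON array of labels only, for example ["economy", "world"].\n'
        "Use an empty array if none apply.\n\n"
        f"Text: {text}"
    )


def parse_labels(response, labels=LABELS):
    """Extract a JSON array from the reply and keep only known labels."""
    start, end = response.find("["), response.rfind("]")
    if start == -1 or end < start:
        return []
    try:
        raw = json.loads(response[start:end + 1])
    except json.JSONDecodeError:
        return []
    allowed = set(labels)
    kept = []
    for item in raw:
        name = str(item).strip().lower()
        # Drop labels outside the allowed set and repeated labels.
        if name in allowed and name not in kept:
            kept.append(name)
    return kept


prompt = build_prompt(
    "Federal Reserve announces interest rate decision affecting tech sector investments"
)
# Stand-in for a model reply; no model is called in this sketch.
fake_reply = 'Labels: ["Economy", "technology", "finance"]'
print(parse_labels(fake_reply))
```

Filtering the reply against the allowed set guards against labels the model invents ("finance" here), which is a common failure mode when the output format is only loosely followed.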

Label dependencies emerge in structured domains. Medical coding systems like ICD-10 contain hierarchical relationships in which assigning a child code implies that its parent category also applies. Product taxonomies follow similar inclusion rules. Prompt engineering can encode these dependencies by stating the rules explicitly, for example by instructing the model to include the parent category whenever it assigns a child label.
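
Hierarchy rules can also be enforced after prediction rather than in the prompt. The sketch below assumes a hypothetical child-to-parent map (the codes are placeholders, not real ICD-10 entries) and adds, for every predicted label, all the ancestors it falls under.

```python
# Hypothetical child -> parent map; the codes are placeholders, not real ICD-10 entries.
PARENT = {
    "A1.1": "A1",
    "A1.2": "A1",
    "A1": "A",
    "B2.1": "B2",
    "B2": "B",
}


def with_ancestors(predicted, parent=PARENT):
    """Return the predicted labels plus every ancestor they fall under."""
    closed = set(predicted)
    for label in predicted:
        current = label
        # Walk up the hierarchy until a root (a label with no parent) is reached.
        while current in parent:
            current = parent[current]
            closed.add(current)
    return sorted(closed)


print(with_ancestors(["A1.2", "B2"]))  # ['A', 'A1', 'A1.2', 'B', 'B2']
```

A post-processing step like this keeps the output consistent with the taxonomy even when the model's raw predictions are not.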

Threshold optimization requires validation data in which every label, and ideally each common label combination, is adequately represented. A grid search over threshold values (typically 0.3 to 0.7 in 0.05 increments) identifies the operating point that maximizes macro-averaged F1 or subset accuracy. The best threshold depends heavily on application requirements: high-stakes applications where false positives are costly typically favor precision over recall, which means a higher threshold.
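
A grid search of this kind might be sketched as follows, using only the standard library. The validation labels and probabilities are invented toy values for three labels, and the helper names are illustrative. A single global threshold is tuned for macro-averaged F1; per-label thresholds would follow the same pattern.

```python
def f1_for_label(y_true, y_pred):
    """F1 for one label from parallel lists of 0/1 values."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def macro_f1(true_rows, prob_rows, threshold):
    """Average the per-label F1 scores at a given probability cutoff."""
    n_labels = len(true_rows[0])
    scores = []
    for j in range(n_labels):
        y_true = [row[j] for row in true_rows]
        y_pred = [int(row[j] > threshold) for row in prob_rows]
        scores.append(f1_for_label(y_true, y_pred))
    return sum(scores) / n_labels


def grid_search_threshold(true_rows, prob_rows, low=0.30, high=0.70, step=0.05):
    """Try each candidate cutoff and return the best one with its macro-F1."""
    n_steps = int(round((high - low) / step))
    candidates = [round(low + i * step, 2) for i in range(n_steps + 1)]
    # max() keeps the first (lowest) threshold when several tie.
    best = max(candidates, key=lambda t: macro_f1(true_rows, prob_rows, t))
    return best, macro_f1(true_rows, prob_rows, best)


# Toy validation set (made up): rows are examples, columns are three labels.
val_true = [
    [1, 0, 1], [1, 1, 0], [0, 1, 0],
    [0, 0, 1], [1, 0, 0], [0, 1, 1],
]
val_probs = [
    [0.82, 0.20, 0.46], [0.64, 0.58, 0.52], [0.33, 0.41, 0.08],
    [0.10, 0.25, 0.71], [0.44, 0.15, 0.22], [0.28, 0.66, 0.38],
]

best_threshold, best_score = grid_search_threshold(val_true, val_probs)
print(best_threshold, round(best_score, 3))
```

On this toy data the default cutoff of 0.5 is noticeably worse than the tuned one, which illustrates why the threshold is worth calibrating rather than assuming.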

EXERCISE

Implement multi-label classification for a domain of your choice. Generate synthetic validation data representing label co-occurrence patterns. Tune thresholds and report precision, recall, and F1 per label.