04. Data Augmentation

Chapter 4 of 18 · 15 min

Data augmentation increases effective dataset size without collecting new data. Done wrong, it introduces artifacts that the model learns instead of real patterns.

Image Augmentations

import torchvision.transforms as T

train_transform = T.Compose([
    T.RandomResizedCrop(224, scale=(0.8, 1.0)),      # Random scale and aspect ratio, not just RandomCrop
    T.RandomHorizontalFlip(p=0.5),                   # Only if mirrored images are still valid samples
    T.RandomRotation(degrees=15),                    # Keep small: large angles can create impossible images
    T.ColorJitter(brightness=0.2, contrast=0.2),     # Lighting variation
    T.RandomGrayscale(p=0.1),                        # Occasional grayscale
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225])
])

val_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
])

The Augmentation Trap

Too much augmentation destroys the signal. A 45-degree rotation on a digit-recognition dataset creates patterns that don't exist in reality. A brightness jitter of 2.0 does not add subtle lighting variation: it yields images anywhere from fully black to badly overexposed.
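To make the magnitude concrete, here is a small sketch of how a brightness setting translates into a scaling range. It assumes torchvision's documented behavior for a float `brightness` argument to `ColorJitter`, where the factor is drawn uniformly from `[max(0, 1 - brightness), 1 + brightness]`. The helper name `jitter_factor_range` is hypothetical and exists only for illustration.

```python
def jitter_factor_range(value):
    """Range of multiplicative factors for a float ColorJitter setting.

    Assumes the torchvision convention: uniform in
    [max(0, 1 - value), 1 + value]. A factor of 0 gives a black image,
    1 leaves the image unchanged.
    """
    return max(0.0, 1.0 - value), 1.0 + value


# Compare the setting used in the pipeline with an overly aggressive one.
for setting in (0.2, 2.0):
    low, high = jitter_factor_range(setting)
    print(f"brightness={setting}: factor in [{low:.1f}, {high:.1f}]")
```

With the aggressive setting, a third of the range darkens the image and the lower end removes it entirely, which is why the samples stop resembling real photos.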

Test augmentations by applying them to a single image 20 times and asking: "Would a human label these consistently?" If not, the augmentation is too aggressive.

Cutout and Mixup

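import numpy as np
import torch


class Cutout:
    """Zero out one random square patch of a (C, H, W) tensor image.

    Apply after T.ToTensor(). The patch is clipped at the image border,
    so patches near an edge are smaller than size x size.
    """

    def __init__(self, size=16):
        self.size = size

    def __call__(self, img):
        h, w = img.shape[1], img.shape[2]
        # Match the image's dtype and device so the multiply below is valid
        mask = torch.ones(h, w, dtype=img.dtype, device=img.device)
        # Random patch center
        y = torch.randint(0, h, (1,)).item()
        x = torch.randint(0, w, (1,)).item()

        # Patch bounds, clipped to the image
        y1 = max(0, y - self.size // 2)
        y2 = min(h, y + self.size // 2)
        x1 = max(0, x - self.size // 2)
        x2 = min(w, x + self.size // 2)

        mask[y1:y2, x1:x2] = 0
        return img * mask.unsqueeze(0)  # Broadcast the mask over channels


def mixup_data(x, y, alpha=1.0):
    """Mix each sample with a randomly chosen partner from the same batch.

    Returns the mixed inputs, both label tensors, and the mixing weight lam.
    """
    if alpha > 0:
        lam = float(np.random.beta(alpha, alpha))
    else:
        lam = 1.0  # No mixing

    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam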
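
`mixup_data` returns two label tensors and `lam` instead of a single blended label, so the loss has to be blended to match. One common way to do that is sketched below. The helper name `mixup_criterion` is an assumption (it is not defined in this chapter), and the random logits stand in for a model's output on the mixed batch.

```python
import torch
import torch.nn.functional as F


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Weight the loss against each label set by its share of the mix."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


torch.manual_seed(0)
logits = torch.randn(8, 10)          # Stand-in for model(mixed_x)
y_a = torch.randint(0, 10, (8,))     # Original labels
y_b = y_a[torch.randperm(8)]         # Labels of the mixed-in partners
lam = 0.7                            # Would come from mixup_data

loss = mixup_criterion(F.cross_entropy, logits, y_a, y_b, lam)
```

In a training loop, `loss.backward()` then proceeds as usual. Accuracy on mixed batches is not directly meaningful, so evaluate on unmixed validation data.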

Local verification checkpoint

Run the smallest example from this chapter in a local workspace and record the package version, runtime, data path, and observed output. If the result depends on model size, vector count, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.
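
A minimal sketch for capturing the environment details to record alongside your results. It uses only the standard library, so it runs even if a package is missing. The helper names and the package list are assumptions chosen to match this chapter's imports.

```python
import json
import platform
import sys
from importlib import metadata


def package_version(name):
    """Return the installed version of a package, or a marker if absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return "not installed"


def environment_report(packages=("torch", "torchvision", "numpy")):
    """Collect interpreter, platform, and package versions in one dict."""
    report = {
        "python": platform.python_version(),
        "platform": platform.platform(),
        "executable": sys.executable,
    }
    for name in packages:
        report[name] = package_version(name)
    return report


report = environment_report()
print(json.dumps(report, indent=2))
```

Paste the printed report next to your observed output, and add the data path and whether you ran on CPU or GPU by hand.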

EXERCISE

Apply your augmentation pipeline to 10 images and save the results. Review them, identify which augmentations create unrealistic samples, and reduce their magnitude.