KEY INSIGHT
Classifying documents before routing them to specialized processing pipelines reduces overall processing time and enables different extraction strategies for different document types.
### Why Classify First
A single document processing pipeline optimized for invoices fails on contracts. Classification enables routing: each document type follows its optimal path. Additionally, classification metadata helps downstream systems understand document provenance.
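As a rough illustration of routing, a dispatch table can map each label to its handler. This is a minimal sketch, assuming the classifier returns a plain string label; the handler names (`process_invoice`, `process_contract`, `process_generic`) are hypothetical placeholders, not part of any library.

```python
# Hypothetical handlers: stand-ins for real extraction pipelines.
def process_invoice(path):
    return {"pipeline": "invoice", "path": path}


def process_contract(path):
    return {"pipeline": "contract", "path": path}


def process_generic(path):
    return {"pipeline": "generic", "path": path}


# Dispatch table: classification label -> handler.
ROUTES = {
    "invoice": process_invoice,
    "contract": process_contract,
}


def route(path, label):
    """Send a document to the handler for its label, or to the generic path."""
    handler = ROUTES.get(label, process_generic)
    result = handler(path)
    # Keep the label with the result so downstream systems can see how it was routed.
    result["classified_as"] = label
    return result
```

Unrecognized labels fall through to the generic handler, so adding a document type only requires a new entry in `ROUTES`.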
### Rule-Based Classification
Simple classification based on keyword matching against the first page's text:
```python
import fitz  # PyMuPDF


def classify_simple(doc_path):
    """Classify a document by keyword markers on its first page."""
    # The context manager closes the file on every return path
    with fitz.open(doc_path) as doc:
        text = doc[0].get_text()[:2000]  # Sample the beginning only

    # Check for type-specific markers
    if any(marker in text for marker in ['INVOICE', 'Invoice #', 'Bill To:', 'Total Due']):
        return 'invoice'
    elif any(marker in text for marker in ['Contract', 'Agreement', 'Whereas']):
        return 'contract'
    elif any(marker in text for marker in ['Dear', 'Sincerely', 'Regards']):
        return 'letter'
    return 'unknown'
```
Rule-based classification is fast and requires no ML model, but it is fragile: small variations in wording, casing, or layout can cause a document to miss every marker.
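One way to soften that fragility, sketched here under assumptions, is to match case-insensitively and count hits per type instead of returning on the first match. The pattern lists and the `classify_by_score` name are illustrative and not tuned on real data.

```python
import re

# Illustrative patterns per document type (assumed, not exhaustive).
PATTERNS = {
    "invoice": [r"\binvoice\b", r"\bbill\s+to\b", r"\btotal\s+due\b"],
    "contract": [r"\bcontract\b", r"\bagreement\b", r"\bwhereas\b"],
    "letter": [r"\bdear\b", r"\bsincerely\b", r"\bregards\b"],
}


def classify_by_score(text):
    """Return the type whose patterns match most often, or 'unknown' if none match."""
    scores = {
        label: sum(len(re.findall(p, text, flags=re.IGNORECASE)) for p in patterns)
        for label, patterns in PATTERNS.items()
    }
    # On a tie, max() keeps the first label in PATTERNS order.
    best = max(scores, key=scores.get)
    return best if scores[best] > 0 else "unknown"
```

Counting matches lets a letter that mentions an "agreement" once still classify as a letter when greeting and sign-off markers outnumber it.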
### ML-Based Classification
Train a classifier on document features:
```bash
pip install scikit-learn pymupdf
```
```python
import fitz  # PyMuPDF
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# Training data: (path, label) tuples
training_data = [
    ('./data/invoices/inv1.pdf', 'invoice'),
    ('./data/invoices/inv2.pdf', 'invoice'),
    ('./data/contracts/contract1.pdf', 'contract'),
    # ... more training examples
]


def extract_features(path):
    """Return the text of the first 3 pages (fewer if the document is shorter)."""
    with fitz.open(path) as doc:
        page_count = min(3, doc.page_count)
        return "".join(doc[i].get_text() for i in range(page_count))


# Build training set
X_train = [extract_features(path) for path, _ in training_data]
y_train = [label for _, label in training_data]

# Train classifier: TF-IDF over unigrams and bigrams, then logistic regression
pipeline = Pipeline([
    ('tfidf', TfidfVectorizer(max_features=5000, ngram_range=(1, 2))),
    ('clf', LogisticRegression(max_iter=1000))
])
pipeline.fit(X_train, y_train)

# Predict: the pipeline expects document text, not a file path
prediction = pipeline.predict([extract_features('new_document.pdf')])[0]
print(f"Classification: {prediction}")
```
### Transformer-Based Classification
For higher accuracy on diverse document types:
```bash
pip install transformers torch
```
```python
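from functools import lru_cache

# Aliased so it does not clobber a scikit-learn `pipeline` variable in the same session
from transformers import pipeline as hf_pipeline

classifier = hf_pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


@lru_cache(maxsize=1000)
def classify_transformer(text, doc_type=None):
    """Return (label, score) for the top-ranked candidate label."""
    # doc_type is not used by the model; it is kept so existing callers still work
    candidate_labels = ["invoice", "contract", "letter", "form", "report"]
    result = classifier(text[:2000], candidate_labels)
    return result['labels'][0], result['scores'][0]


# Usage (extract_features is the text-extraction helper from the ML-based example)
text = extract_features('document.pdf')
label, confidence = classify_transformer(text, 'document')
print(f"{label} ({confidence:.2f})")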
```
Zero-shot classification requires no training data: you specify candidate labels and the model ranks them against the text. It works well when you have 5-10 known document types.
### Multi-Modal Classification
Some documents are primarily images. Classify by visual features:
```python
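import io

import fitz  # PyMuPDF
import torch
from PIL import Image
from torchvision import models, transforms

# ImageNet-pretrained backbone; `weights=` replaces the deprecated pretrained=True
# (requires torchvision 0.13 or newer)
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model.fc = torch.nn.Linear(model.fc.in_features, 5)  # 5 document types
# NOTE: the new head is randomly initialized; fine-tune on labeled page images before use
model.eval()  # Inference mode for batch norm and dropout

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def classify_visual(doc_path):
    # Render first page as image
    with fitz.open(doc_path) as doc:
        mat = fitz.Matrix(2, 2)  # 2x zoom
        pix = doc[0].get_pixmap(matrix=mat)
        img_data = pix.tobytes("png")

    # Force 3 channels so the tensor matches the Normalize statistics
    image = Image.open(io.BytesIO(img_data)).convert("RGB")
    tensor = transform(image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        logits = model(tensor)
    # Order must match the label indices used during fine-tuning
    classes = ['form', 'invoice', 'letter', 'report', 'contract']
    return classes[logits.argmax().item()]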
```
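To decide when the visual path is needed, one simple heuristic is to check how much text the text layer actually yields. The sketch below assumes a threshold of 50 non-whitespace characters, an arbitrary value that would need tuning, and the function name `choose_modality` is hypothetical.

```python
def choose_modality(extracted_text, min_chars=50):
    """Return 'text' if the text layer looks usable, otherwise 'visual'.

    min_chars is an assumed threshold; scanned pages often yield little or no text.
    """
    # Count non-whitespace characters so blank lines and padding do not count.
    visible = sum(1 for ch in extracted_text if not ch.isspace())
    return "text" if visible >= min_chars else "visual"
```

A caller could pass in the output of `extract_features` and invoke `classify_visual` only when the result is `'visual'`.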
### Classification Confidence and Fallback
Always check the classifier's confidence score and route low-confidence predictions to a fallback:
```python
def classify_with_fallback(doc_path):
    text = extract_features(doc_path)
    # Try ML classifier
    label, conf = classify_transformer(text, 'document')
    if conf < 0.6:
        # Low confidence - use rule-based as fallback
        rule_label = classify_simple(doc_path)
        print(f"Low confidence ({conf:.2f}), rule-based suggests: {rule_label}")
        return rule_label
    return label
```
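The same idea generalizes to a chain of classifiers tried in order, ending in a manual-review bucket when none is confident. This is only a sketch under assumptions: classifiers are passed in as plain callables returning `(label, confidence)`, and the `'needs_review'` label, the 0.6 threshold, and the two stand-in classifiers are placeholders.

```python
def classify_chain(text, classifiers, threshold=0.6):
    """Try each (name, classifier) pair in order; accept the first confident answer.

    Each classifier is a callable taking text and returning (label, confidence).
    Returns (label, confidence, source_name).
    """
    for name, classify in classifiers:
        label, conf = classify(text)
        if label != "unknown" and conf >= threshold:
            return label, conf, name
    # Nothing was confident enough: flag for a human instead of guessing.
    return "needs_review", 0.0, "none"


# Stand-in classifiers for demonstration only.
def unsure_model(text):
    return "report", 0.41


def keyword_rules(text):
    return ("invoice", 1.0) if "invoice" in text.lower() else ("unknown", 0.0)
```

Returning the name of the classifier that produced the label makes it possible to audit later how often each stage of the chain was actually used.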