DeepTab is a Python library for deep learning on tabular data, built on PyTorch and Lightning with a scikit-learn compatible API. It offers 15 neural architectures, from Mamba-inspired state space models and Transformers to tree ensembles and MLP baselines, each available as a classifier, regressor, or distributional (LSS) model. One fit/predict/evaluate workflow covers everyday modeling, architecture research, and production deployment.
- Familiar interface. A scikit-learn fit/predict/evaluate API that drops into existing pipelines, including GridSearchCV.
- Automatic preprocessing. Feature-type detection, encoding, scaling, and missing-value handling are built in.
- One model, three tasks. Every architecture ships as a classifier, a regressor, and a distributional (LSS) variant for uncertainty quantification.
- A broad model zoo. 15 stable architectures plus experimental models, all behind the same interface, with selection guidance.
- Built for real data. Mixed feature types, class imbalance, GPU acceleration, and early stopping work out of the box.
v2.0 is a ground-up restructuring of DeepTab. The high-level estimator API (MambularClassifier().fit(...)) is largely unchanged, but the internal package layout, configuration objects, and import paths have moved.
⚠️ Upgrading from v1? Packages were reorganised, the Default<Arch>Config classes were renamed to <Arch>Config, and the data modules were renamed to TabularDataModule / TabularDataset. Code that only uses the high-level estimators mostly keeps working; code that imported internal modules needs updating. See the FAQ for v1 support and upgrade notes.
- Split-config API: The model, preprocessing, and training each have their own configuration object, so you can tune one concern without disturbing the others. This is the first thing you reach for in v2.
- Typed data layer: TabularDataset, TabularDataModule, and FeatureSchema give the data pipeline an explicit, inspectable contract, with stratified splitting controlled through TrainerConfig.
- New stable models: AutoInt, ENODE, and TabR.
- New experimental models: Tangos, Trompt, and ModernNCA, under evaluation for promotion.
- Observability and experiment tracking: ObservabilityConfig adds structured lifecycle logging via structlog and one-line MLflow or TensorBoard tracking, with every run saved to an organised directory tree. It is opt-in and silent by default.
- Registry-driven training: Every torch.optim optimizer, learning-rate scheduler, and loss is selectable by name through TrainerConfig, and you can register your own at runtime.
- Unified metrics: deeptab.metrics ships 25+ metric classes for regression, classification, and distributional models, auto-selected per task through a registry.
- Reproducibility: set_seed and seed_context seed Python, NumPy, and PyTorch across CPU, CUDA, and MPS, including the DataLoader and sampler generators.
- Deployment-safe inference: InferenceModel wraps a fitted estimator in a read-only prediction surface with schema validation and task-type enforcement. Training methods are deliberately absent, so a served model cannot be re-fitted by accident.
- Self-describing artifacts: save and load go through a single .deeptab format that bundles the architecture, feature schema, preprocessing, task type, and package versions alongside the weights, so a saved model carries everything needed to reload it.
- Rebuilt from the ground up: Getting Started, Core Concepts, and the Model Zoo.
- End-to-end tutorials: runnable walkthroughs with Colab covering imbalanced classification, skewed regression, uncertainty quantification, hyperparameter tuning, and observability.
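The deployment-safe inference idea listed above (a prediction-only surface with schema validation) can be illustrated with a small, library-independent sketch. Everything in it is hypothetical: `ReadOnlyPredictor`, `expected_columns`, and `DummyEstimator` are illustrative names, not DeepTab's InferenceModel API, and the real class presumably validates much more than column names.

```python
class ReadOnlyPredictor:
    """Hypothetical prediction-only wrapper; NOT DeepTab's InferenceModel."""

    def __init__(self, estimator, expected_columns):
        self._estimator = estimator
        self._expected = list(expected_columns)

    def predict(self, rows):
        # rows: list of dicts mapping column name -> value
        for i, row in enumerate(rows):
            if list(row) != self._expected:
                raise ValueError(
                    f"row {i}: expected columns {self._expected}, got {list(row)}"
                )
        return self._estimator.predict(rows)


class DummyEstimator:
    """Stand-in for a fitted model so the sketch runs on its own."""

    def fit(self, rows, y):
        return self

    def predict(self, rows):
        return [row["age"] > 40 for row in rows]


served = ReadOnlyPredictor(DummyEstimator(), expected_columns=["age", "plan"])
print(served.predict([{"age": 52, "plan": "pro"}]))
print(hasattr(served, "fit"))  # the wrapper exposes no training method
```

The point of the design is that the serving object simply has no `fit` attribute to call, and malformed input fails loudly before it reaches the model.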
from deeptab.models import MambularClassifier
# Initialize and fit (sklearn-compatible)
model = MambularClassifier()
model.fit(X_train, y_train, max_epochs=50)
# Predict
predictions = model.predict(X_test)
probabilities = model.predict_proba(X_test)
That's it! DeepTab handles preprocessing, batching, and training automatically.
Works with pandas & numpy: Pass DataFrames or arrays, and DeepTab auto-detects feature types.
DeepTab provides 15 stable architectures across five families: State Space Models (Mambular, MambaTab, MambAttention), Transformers (FTTransformer, TabTransformer, SAINT, AutoInt), residual networks (ResNet, TabR), tree-inspired models (NODE, ENODE, NDTF), and general baselines (MLP, TabM, TabulaRNN). Three experimental models (ModernNCA, Tangos, Trompt) are under evaluation for promotion.
See the Model Zoo for detailed comparisons, complexity analysis, and selection guidance.
| Category | Model | Architecture | Best For |
|---|---|---|---|
| State Space Models | Mambular | Stacked Mamba over feature tokens | General-purpose tabular modeling |
| | MambaTab | Lightweight Mamba SSM | Small datasets and fast training |
| | MambAttention | Mamba with feature attention | Feature-interaction-heavy data |
| Transformers | FTTransformer | Feature Tokenizer + Transformer | Strong attention-based baseline |
| | TabTransformer | Transformer over categorical tokens | Categorical-heavy data |
| | SAINT | Row and column attention | Small or label-scarce datasets |
| | AutoInt | Self-attentive feature interactions | Automatic high-order interactions |
| Residual Networks | ResNet | Residual MLP | Fast dense baseline |
| | TabR | Retrieval-augmented MLP/kNN | Large datasets with neighbor signal |
| Tree-Inspired | NODE | Neural oblivious decision ensembles | Differentiable tree inductive bias |
| | ENODE | Embedded NODE-style soft trees | Tree-inspired modeling with embeddings |
| | NDTF | Neural decision tree forest | Differentiable forest experiments |
| Other | MLP | Feedforward dense network | Fastest baseline |
| | TabM | Parameter-efficient ensemble MLP | Strong efficient baseline |
| | TabulaRNN | Recurrent feature-sequence model | Sequential feature modeling |
⚠️ API Not Stable: Experimental models may change in minor releases. Always pin an exact version: deeptab==x.y.z
- ModernNCA: Neighborhood Component Analysis (metric learning)
- Tangos: Gradient orthogonalization approach
- Trompt: Prompt-based learning for tabular data
All models come in three variants:
- *Classifier: Classification (binary & multi-class)
- *Regressor: Regression (point estimates)
- *LSS: Distributional regression (full distribution prediction)
Consistent API: All models use the same interface, so you can swap architectures without changing code.
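Because every estimator shares the same fit/evaluate surface, comparing architectures can be a plain loop. The sketch below is a minimal illustration, not part of DeepTab: `compare_estimators` is a hypothetical helper, and `StubEstimator` stands in for a real model so the example runs without DeepTab installed. In practice you would pass classes such as MambularClassifier instead.

```python
def compare_estimators(estimator_classes, X_train, y_train, X_test, y_test, **fit_kwargs):
    """Fit each estimator class with default settings and collect evaluate() output."""
    results = {}
    for cls in estimator_classes:
        model = cls()
        model.fit(X_train, y_train, **fit_kwargs)
        results[cls.__name__] = model.evaluate(X_test, y_test)
    return results


class StubEstimator:
    """Stand-in with the same fit/evaluate shape; always predicts the majority class."""

    def fit(self, X, y, **kwargs):
        labels = list(y)
        self.majority_ = max(set(labels), key=labels.count)
        return self

    def evaluate(self, X, y):
        hits = sum(1 for label in y if label == self.majority_)
        return {"accuracy": hits / len(y)}


scores = compare_estimators(
    [StubEstimator], [[0], [1], [2]], [1, 1, 0], [[3], [4]], [1, 0], max_epochs=1
)
print(scores)
```

Swapping architectures then means editing the list of classes and nothing else.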
Full documentation: deeptab.readthedocs.io
- Getting Started: Installation, quickstart, FAQ
- Core Concepts: sklearn API, config system, preprocessing, training
- Tutorials: Classification, regression, LSS (with Google Colab)
- Model Zoo: Model selection, comparisons, recommended configs
- API Reference: Complete API documentation
Basic installation:
pip install deeptab
With experiment tracking and structured logging:
pip install 'deeptab[tracking]' # MLflow + TensorBoard loggers
pip install 'deeptab[logs]' # structured logging via structlog
pip install 'deeptab[all]' # every optional backend
Faster Mamba models (optional CUDA kernels):
pip install mamba-ssm
Mamba kernels are optional: They give a 20-30% speedup for Mamba-based models on a compatible NVIDIA GPU (CUDA 11.6+). If the install fails or no GPU is present, DeepTab falls back to a pure-PyTorch implementation automatically.
Lightweight by default: Tracking backends are optional and imported lazily, so a plain pip install deeptab stays small. Install only the extras you actually use.
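For readers unfamiliar with lazily imported optional backends, the general pattern looks roughly like the sketch below. It is an assumption about the technique in general, not DeepTab's actual code: `optional_import` is a hypothetical helper, and the stdlib `json` module stands in for a tracking backend so the example runs anywhere.

```python
import importlib


def optional_import(module_name, extra):
    """Import a module on first use; point to the matching pip extra if it is missing."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"'{module_name}' is not installed. Try: pip install 'deeptab[{extra}]'"
        ) from exc


# Nothing is imported until a feature that needs the backend is actually used.
json_mod = optional_import("json", extra="all")  # stdlib module, always present
print(json_mod.dumps({"ok": True}))
```

The base install never pays the import cost of a backend it does not use, and a missing extra produces an actionable error only at the point of use.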
Requirements: Python 3.10+, PyTorch 2.2+, Lightning 2.3.3+
GPU Support: See installation guide for CUDA setup.
from deeptab.models import MambularClassifier
from deeptab.configs import MambularConfig, PreprocessingConfig, TrainerConfig
# 1. Initialize with configuration (optional - defaults work well!)
model_config = MambularConfig(d_model=64, n_layers=6)
prep_config = PreprocessingConfig(numerical_preprocessing="quantile")
trainer_config = TrainerConfig(lr=1e-4, batch_size=256)
model = MambularClassifier(
    model_config=model_config,
    preprocessing_config=prep_config,
    trainer_config=trainer_config
)
# 2. Fit (X can be pandas DataFrame or numpy array)
model.fit(X_train, y_train, max_epochs=50)
# 3. Predict
predictions = model.predict(X_test)
probabilities = model.predict_proba(X_test)
# 4. Evaluate
metrics = model.evaluate(X_test, y_test)
# Regression: {"rmse": ..., "mae": ..., "r2": ...}
# Classification: {"accuracy": ..., "auroc": ..., "log_loss": ...}
# LSS (normal): {"crps": ..., "rmse": ..., "mae": ...}
Tip: Start with defaults (MambularClassifier()) and tune only if needed. See Recommended Configs for guidance.
DeepTab models are sklearn-compatible, so you can use GridSearchCV:
from sklearn.model_selection import GridSearchCV
from deeptab.models import MambularClassifier
param_grid = {
    "model_config__d_model": [64, 128, 256],
    "model_config__n_layers": [4, 6, 8],
    "trainer_config__lr": [1e-4, 5e-4, 1e-3],
}
search = GridSearchCV(
    MambularClassifier(),
    param_grid,
    cv=5,
    scoring="accuracy"
)
search.fit(X_train, y_train)
print(f"Best params: {search.best_params_}")
print(f"Best score: {search.best_score_}")
Built-in HPO: Every estimator exposes optimize_hparams(), which runs Gaussian process Bayesian optimization (via scikit-optimize) over a search space derived from the model config. See the HPO Tutorial.
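Before launching the grid search above, it helps to count how many training runs it implies. This small sketch uses only the grid and cv value from the example and expands the grid the way an exhaustive search does, as a Cartesian product of all value lists.

```python
from itertools import product

param_grid = {
    "model_config__d_model": [64, 128, 256],
    "model_config__n_layers": [4, 6, 8],
    "trainer_config__lr": [1e-4, 5e-4, 1e-3],
}
cv = 5

# One candidate per combination of values across all keys.
candidates = [dict(zip(param_grid, values)) for values in product(*param_grid.values())]
n_fits = len(candidates) * cv  # cross-validation fits, before any final refit

print(len(candidates), n_fits)
```

That is 27 candidates and 135 cross-validation fits, plus one more fit on the full training set because GridSearchCV refits the best candidate by default. Each fit is a full neural network training run, so a smaller grid or the built-in optimize_hparams() may be the more practical starting point.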
Predict a full distribution instead of a single point estimate:
from deeptab.models import MambularLSS
# Choose a distribution family when you fit
model = MambularLSS()
model.fit(X_train, y_train, family="normal", max_epochs=50)
# predict() returns the estimated distribution parameters per sample
# (for "normal", that is the location and scale)
params = model.predict(X_test)
# Evaluate with proper scoring rules selected for the family
metrics = model.evaluate(X_test, y_test)
Available families: normal, lognormal, studentt, gamma, beta, tweedie, poisson, zip, negativebinom, dirichlet, mog, quantile, and more. Each family auto-selects appropriate evaluation metrics (CRPS, deviances, NLL).
Prediction intervals: Turn the predicted parameters into calibrated intervals as shown in the Uncertainty Quantification tutorial.
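As a rough illustration for the normal family, a central interval follows directly from a predicted location and scale. The sketch below makes two assumptions that the tutorial should be checked against: that the scale is a standard deviation (not a variance), and that the parameters have already been unpacked into (location, scale) pairs. `normal_interval` is a hypothetical helper, not a DeepTab function, and the numbers are made up.

```python
from statistics import NormalDist


def normal_interval(loc, scale, coverage=0.9):
    """Central interval of a Normal(loc, scale) with the given nominal coverage."""
    if not 0.0 < coverage < 1.0:
        raise ValueError("coverage must be in (0, 1)")
    if scale <= 0:
        raise ValueError("scale must be positive")
    z = NormalDist().inv_cdf(0.5 + coverage / 2.0)  # two-sided quantile
    return loc - z * scale, loc + z * scale


# Hypothetical predicted parameters for two samples: (location, scale)
params = [(10.0, 2.0), (3.5, 0.5)]
intervals = [normal_interval(loc, scale, coverage=0.95) for loc, scale in params]
print(intervals)
```

These intervals have nominal coverage only. Whether they are calibrated in practice depends on how well the fitted distribution matches held-out data.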
DeepTab includes comprehensive preprocessing powered by PreTab:
from deeptab.configs import PreprocessingConfig
from deeptab.models import MambularClassifier
prep_config = PreprocessingConfig(
    numerical_preprocessing="ple", # Piecewise linear encoding
    n_bins=50 # Number of bins for the encoding
)
model = MambularClassifier(preprocessing_config=prep_config)
model.fit(X_train, y_train, max_epochs=50)
Features:
- Automatic detection: Feature types detected from data
- Type-aware: Separate strategies for numerical and categorical features
- Methods: PLE, quantile transform, splines, standardization, min-max, and robust scaling
- Pre-trained encodings: Transfer learning for categorical features
Learn more: Preprocessing is driven by PreprocessingConfig; see the Config System guide and the PreTab project.
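To make piecewise linear encoding (PLE) concrete, here is a simplified pure-Python sketch of the idea: a numeric value becomes one number per bin, equal to 1 for bins it has fully passed, 0 for bins it has not reached, and a fraction for the bin it falls into. This is an illustration under stated assumptions, not PreTab's implementation. The helper names are hypothetical, the bin edges are taken from simple quantiles (PreTab may choose them differently), and every component is clamped to [0, 1].

```python
def quantile_edges(values, n_bins):
    """Bin edges at evenly spaced ranks of the sorted training values."""
    ordered = sorted(values)
    last = len(ordered) - 1
    edges = [ordered[round(k * last / n_bins)] for k in range(n_bins + 1)]
    return sorted(set(edges))  # drop duplicate edges caused by tied values


def ple_encode(x, edges):
    """Piecewise linear encoding of scalar x: one value in [0, 1] per bin."""
    encoded = []
    for left, right in zip(edges[:-1], edges[1:]):
        frac = (x - left) / (right - left)
        encoded.append(min(1.0, max(0.0, frac)))
    return encoded


edges = quantile_edges([0, 1, 2, 3, 4, 5, 6, 7, 8], n_bins=4)
print(edges)
print(ple_encode(5, edges))
```

With edges [0, 2, 4, 6, 8], the value 5 encodes as [1.0, 1.0, 0.5, 0.0]: two bins fully passed, halfway through the third, and the fourth not reached.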
DeepTab can record what happens during training without you writing any callbacks. Pass an ObservabilityConfig when you build a model, and each run captures its hyperparameters, lifecycle events, and final metrics in one self-contained folder.
from deeptab.core.observability import ObservabilityConfig
from deeptab.models import MambularClassifier
obs = ObservabilityConfig(
    experiment_name="churn_baseline",
    structured_logging=True, # human-readable console + JSON event log
    experiment_trackers=["mlflow"], # also supports "tensorboard"
)
model = MambularClassifier(observability_config=obs)
model.fit(X_train, y_train, max_epochs=50)
Every fit produces a tidy, reproducible run directory:
deeptab_runs/
runs/churn_baseline/20260611_174830_8f3a2c/
config.yaml # estimator hyperparameters
lifecycle.jsonl # structured event log
summary.json # final metrics
checkpoints/best.ckpt
tensorboard/...
mlflow/...
Tune the noise: verbosity controls how much is emitted (0 silent, 1 milestones, 2 detailed, 3 debug). The default keeps notebooks quiet.
🔬 For researchers: Lifecycle events such as fit.started, model.created, and train.completed carry structured metadata (sample counts, parameter counts, best validation loss), so you can script experiment sweeps and compare runs programmatically.
Learn more: Observability
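Since the event log is JSON Lines, comparing runs programmatically can start from a few lines of standard-library Python. The sketch below is an assumption-heavy illustration: the exact keys in lifecycle.jsonl are not documented here, so the `event`, `n_samples`, and `best_val_loss` field names are guesses, and `load_events` is a hypothetical helper. It runs against a synthetic log written to a temporary directory.

```python
import json
import tempfile
from pathlib import Path


def load_events(path, name=None):
    """Read a JSON Lines file; optionally keep only events with the given name."""
    events = []
    for line in Path(path).read_text(encoding="utf-8").splitlines():
        if not line.strip():
            continue  # tolerate blank lines
        record = json.loads(line)
        if name is None or record.get("event") == name:
            events.append(record)
    return events


# Synthetic log for the demo; the key names are assumed, not documented.
with tempfile.TemporaryDirectory() as tmp:
    log = Path(tmp) / "lifecycle.jsonl"
    log.write_text(
        '{"event": "fit.started", "n_samples": 1000}\n'
        '{"event": "train.completed", "best_val_loss": 0.42}\n',
        encoding="utf-8",
    )
    completed = load_events(log, name="train.completed")

print(completed)
```

Pointing `load_events` at each run's lifecycle.jsonl and collecting the completion events gives a simple table of runs to sort or plot, once the real field names are confirmed.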
Implement your own architecture with DeepTab's base classes. A model is three
small pieces: a dataclass config (subclassing BaseModelConfig), a PyTorch
architecture (subclassing BaseModel), and one estimator per task that
binds them via _model_cls / _config_cls:
from dataclasses import dataclass, field
import torch
import torch.nn as nn
from deeptab.configs import BaseModelConfig, TrainerConfig
from deeptab.core import BaseModel, get_feature_dimensions
from deeptab.models import SklearnBaseRegressor
@dataclass
class MyCustomConfig(BaseModelConfig):
    layer_sizes: list = field(default_factory=lambda: [128, 64])
    dropout: float = 0.1
class MyCustomModel(BaseModel):
    def __init__(
        self,
        feature_information: tuple, # (num_info, cat_info, embedding_info)
        num_classes: int = 1,
        config: MyCustomConfig = MyCustomConfig(), # noqa: B008
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)
        self.save_hyperparameters(ignore=["feature_information"])
        # Input width is derived from the data, never hard-coded.
        input_dim = get_feature_dimensions(*feature_information)
        layers: list[nn.Module] = []
        prev = input_dim
        for size in self.hparams.layer_sizes:
            layers += [nn.Linear(prev, size), nn.ReLU(), nn.Dropout(self.hparams.dropout)]
            prev = size
        layers.append(nn.Linear(prev, num_classes))
        self.layers = nn.Sequential(*layers)
    def forward(self, *data) -> torch.Tensor:
        # data == (num_features, cat_features, embeddings)
        x = torch.cat([t for group in data for t in group], dim=1)
        return self.layers(x)
class MyRegressor(SklearnBaseRegressor):
    _model_cls = MyCustomModel
    _config_cls = MyCustomConfig
# Use like any other DeepTab model
model = MyRegressor(
    model_config=MyCustomConfig(layer_sizes=[256, 128]),
    trainer_config=TrainerConfig(lr=1e-3),
)
model.fit(X_train, y_train, max_epochs=50)
Learn more: Custom Models walks through configs, embeddings, and the *Classifier / *Regressor / *LSS variants.
🛠️ Developer Guide: See Contributing for architecture guidelines.
If you use DeepTab in your research, please cite:
@article{thielmann2024mambular,
title={Mambular: A Sequential Model for Tabular Deep Learning},
author={Thielmann, Anton Frederik and Kumar, Manish and Weisser, Christoph and Reuter, Arik and S{\"a}fken, Benjamin and Samiee, Soheila},
journal={arXiv preprint arXiv:2408.06291},
year={2024}
}
@article{thielmann2024efficiency,
title={On the Efficiency of NLP-Inspired Methods for Tabular Deep Learning},
author={Thielmann, Anton Frederik and Samiee, Soheila},
journal={arXiv preprint arXiv:2411.17207},
year={2024}
}
DeepTab is licensed under the MIT License. See LICENSE for details.
Contributions are welcome. See the Contributing Guide to get started, and please follow our Code of Conduct.
- Issues: GitHub Issues
- Discussions: GitHub Discussions
