Paper Title

Attention augmented differentiable forest for tabular data

Authors

Chen, Yingshi

Abstract

A differentiable forest is an ensemble of decision trees with full differentiability. Its simple tree structure is easy to use and to interpret, and full differentiability allows it to be trained end to end with gradient-based optimization methods. In this paper, we propose a tree attention block (TAB) within the differentiable forest framework. The TAB has two operations, squeeze and regulate: the squeeze operation extracts a characteristic of each tree, and the regulate operation learns nonlinear relations between the trees. The TAB thus learns the importance of each tree and adjusts its weight to improve accuracy. Our experiments on large tabular datasets show that the attention augmented differentiable forest achieves accuracy comparable to gradient boosted decision trees (GBDT), the state-of-the-art algorithm for tabular datasets. On some datasets, our model is more accurate than the best GBDT libraries (LightGBM, CatBoost, and XGBoost). The differentiable forest model supports batch training with a batch size much smaller than the training set, so on larger datasets its memory usage is far lower than that of GBDT models. The source code is available at https://github.com/closest-git/QuantumForest.
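
To make the squeeze/regulate description concrete, below is a minimal PyTorch sketch of how such a block could reweight per-tree outputs, in the spirit of squeeze-and-excitation. The class name, the choice of batch-average pooling for the squeeze, and the bottleneck layer sizes are illustrative assumptions on my part, not the released QuantumForest implementation.

import torch
import torch.nn as nn

class TreeAttentionBlock(nn.Module):
    # Hypothetical squeeze-and-regulate block over the responses of
    # n_trees trees. Layer sizes and pooling are assumptions, not the
    # authors' code (see https://github.com/closest-git/QuantumForest).
    def __init__(self, n_trees: int, reduction: int = 4):
        super().__init__()
        hidden = max(n_trees // reduction, 1)
        # "Regulate": a small bottleneck MLP that learns nonlinear
        # relations between trees and emits one weight per tree.
        self.regulate = nn.Sequential(
            nn.Linear(n_trees, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_trees),
            nn.Sigmoid(),
        )

    def forward(self, tree_responses: torch.Tensor) -> torch.Tensor:
        # tree_responses: (batch, n_trees), the output of each tree.
        # "Squeeze": pool each tree's response over the batch into a
        # single characteristic scalar per tree.
        squeezed = tree_responses.mean(dim=0)   # (n_trees,)
        weights = self.regulate(squeezed)       # (n_trees,) in (0, 1)
        # Reweight each tree before the ensemble combines their outputs.
        return tree_responses * weights

# Usage: reweight 64 trees' responses for a batch of 512 samples.
tab = TreeAttentionBlock(n_trees=64)
out = tab(torch.randn(512, 64))  # shape (512, 64)

Because the weights come from a differentiable MLP, the block trains jointly with the forest under the same gradient-based optimizer, which is what lets it learn each tree's importance end to end.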
