Paper title
Practical Conditional Neural Processes Via Tractable Dependent Predictions
Paper authors
Abstract
Conditional Neural Processes (CNPs; Garnelo et al., 2018a) are meta-learning models which leverage the flexibility of deep learning to produce well-calibrated predictions and naturally handle off-the-grid and missing data. CNPs scale to large datasets and train with ease. Due to these features, CNPs appear well-suited to tasks from environmental sciences or healthcare. Unfortunately, CNPs do not produce correlated predictions, making them fundamentally inappropriate for many estimation and decision-making tasks. Predicting heat waves or floods, for example, requires modelling dependencies in temperature or precipitation over time and space. Existing approaches which model output dependencies, such as Neural Processes (NPs; Garnelo et al., 2018b) or the FullConvGNP (Bruinsma et al., 2021), are either complicated to train or prohibitively expensive. What is needed is an approach which provides dependent predictions, but is simple to train and computationally tractable. In this work, we present a new class of Neural Process models that make correlated predictions and support exact maximum likelihood training that is simple and scalable. We extend the proposed models by using invertible output transformations to capture non-Gaussian output distributions. Our models can be used in downstream estimation tasks which require dependent function samples. By accounting for output dependencies, our models show improved predictive performance on a range of experiments with synthetic and real data.
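To make the distinction between independent and correlated predictions concrete, the following is a minimal sketch and not the paper's model. It assumes a two-point Gaussian predictive whose means, standard deviations, and correlation `rho` are given directly (in a Neural Process these would be network outputs). The function names and the numeric values are hypothetical and chosen only for illustration. Both log-densities are available in closed form, which is the sense in which exact maximum likelihood training is tractable for a Gaussian predictive.

```python
import math


def independent_log_density(y1, y2, m1, m2, s1, s2):
    """Log-density under two independent Gaussians (a factorised predictive)."""
    def log_normal(y, m, s):
        return -0.5 * math.log(2.0 * math.pi * s * s) - 0.5 * ((y - m) / s) ** 2
    return log_normal(y1, m1, s1) + log_normal(y2, m2, s2)


def correlated_log_density(y1, y2, m1, m2, s1, s2, rho):
    """Exact log-density of a bivariate Gaussian with correlation rho in (-1, 1)."""
    z1 = (y1 - m1) / s1
    z2 = (y2 - m2) / s2
    one_minus_rho_sq = 1.0 - rho * rho
    # Quadratic form of the standardised residuals under the joint covariance.
    quad = (z1 * z1 - 2.0 * rho * z1 * z2 + z2 * z2) / one_minus_rho_sq
    # Log of the normalising constant 2*pi*s1*s2*sqrt(1 - rho^2).
    log_norm = math.log(2.0 * math.pi * s1 * s2) + 0.5 * math.log(one_minus_rho_sq)
    return -0.5 * quad - log_norm


# Hypothetical scenario: two neighbouring locations are both two standard
# deviations above their predicted mean (for instance, during a heat wave).
ll_independent = independent_log_density(2.0, 2.0, 0.0, 0.0, 1.0, 1.0)
ll_correlated = correlated_log_density(2.0, 2.0, 0.0, 0.0, 1.0, 1.0, rho=0.9)
print(ll_independent, ll_correlated)
```

In this example both predictives have identical marginals, yet the correlated one assigns a higher log-density to the joint extreme event. A factorised predictive cannot express that neighbouring values tend to be high together, which is why it is ill-suited to estimating the probability of events such as heat waves.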