谷歌发布JAX-Privacy 1.0:大规模差分隐私机器学习工具库

谷歌正式发布JAX-Privacy 1.0,这是基于高性能计算库JAX构建的差分隐私机器学习工具包。该库集成了最新研究成果,采用模块化设计,使研究人员和开发者能够更轻松地构建差分隐私训练管道。JAX-Privacy提供梯度裁剪、噪声生成、批量选择等核心组件,支持大规模分布式训练,已成功应用于VaultGemma等先进模型的训练中。

我们宣布发布JAX-Privacy 1.0,这是一个基于高性能计算库JAX构建的差分隐私机器学习工具库。

从个性化推荐到科学进步,AI模型正在帮助改善生活并改变各个行业。但这些AI模型的影响力和准确性往往取决于其使用的数据质量。大规模、高质量的数据集对于开发准确且具有代表性的AI模型至关重要,但必须以保护个人隐私的方式使用这些数据。

这就是JAX和JAX-Privacy发挥作用的地方。JAX于2020年推出,是一个专为大规模机器学习设计的高性能数值计算库。其核心功能包括自动微分、即时编译和跨多个加速器的无缝扩展,使其成为高效构建和训练复杂模型的理想平台。JAX已成为推动AI边界发展的研究人员和工程师的基石。其周围的生态系统包括一套强大的领域特定库,包括简化神经网络架构实现的Flax,以及实现最先进优化器的Optax。

基于JAX构建的JAX-Privacy是一个强大的工具包,用于构建和审计差分隐私模型。它使研究人员和开发者能够快速高效地实现差分隐私算法,用于在大型数据集上训练深度学习模型,并提供将隐私训练集成到现代分布式训练工作流程所需的核心工具。JAX-Privacy的原始版本于2022年推出,旨在让外部研究人员能够重现和验证我们在隐私训练方面的一些进展。此后,它已发展成为Google内部研究团队整合其在差分隐私训练和审计算法方面新颖研究见解的中心。

今天,我们自豪地宣布发布JAX-Privacy 1.0。这个新版本整合了我们最新的研究进展,重新设计以实现模块化,使研究人员和开发者比以往任何时候都更容易构建差分隐私训练流水线,将最先进的差分隐私算法与JAX提供的可扩展性相结合。

多年来,研究人员一直将差分隐私作为量化和限制隐私泄露的黄金标准。差分隐私保证算法的输出几乎相同,无论数据集中是否包含单个个体或示例。

虽然差分隐私的理论已经成熟,但在大规模机器学习中的实际实现仍然是一个挑战。最常见的方法是差分隐私随机梯度下降,需要定制的批处理程序、逐例梯度裁剪以及添加精心校准的噪声。这个过程计算量大,很难正确高效地实现,特别是在现代基础模型的规模下。

JAX-Privacy使研究人员和开发者能够使用最先进的差分隐私算法在私有数据上以可扩展和高效的方式训练和微调基础模型,这要归功于其用于梯度裁剪和相关噪声生成的原始构建块,这两者在分布式环境中都能有效工作。

现有框架已经取得了进展,但在可扩展性或灵活性方面往往不足。我们的工作一直在推动隐私机器学习的边界,从开创新的差分隐私算法到开发复杂的审计技术。我们需要一个能够跟上我们研究步伐的工具——一个不仅正确高效,而且从头开始设计来处理最先进模型的并行性和复杂性的库。

JAX的函数式范式和强大的变换,如vmap(用于自动向量化)和shard_map(用于单程序多数据并行化),提供了坚实的基础。通过基于JAX构建,我们可以创建一个开箱即用的并行就绪库,支持跨多个加速器和超级计算机训练大规模模型。JAX-Privacy是这一努力的成果,一个经过时间考验的库,已为内部生产集成提供支持,现在正与更广泛的社区分享。

JAX-Privacy通过提供一套精心设计的组件简化了差分隐私的复杂性:

JAX-Privacy实现了各种基础工具,用于裁剪、噪声添加、批选择、记账和审计,这些工具可以以各种方式组合来构建端到端的差分隐私训练计划。

JAX-Privacy最令人兴奋的方面之一是其实际应用。该库设计用于支持用于预训练和微调大语言模型的现代机器学习框架。一个值得注意的例子是我们最近在训练VaultGemma时使用JAX-Privacy构建块,这是世界上最强大的差分隐私大语言模型。

通过这次开源发布,我们希望使开发者能够通过流行的Keras框架仅用几行代码轻松微调大型模型。特别是,我们包含了微调Gemma系列模型的完整功能示例,这是由Google DeepMind基于Gemini构建的开放模型集合。这些示例展示了如何将JAX-Privacy应用于对话摘要和合成数据生成等任务,表明即使在使用最先进的模型时,该库也能提供最先进的结果。

通过简化差分隐私的集成,JAX-Privacy使开发者能够从头开始构建保护隐私的应用程序,无论他们是在为医疗应用微调聊天机器人还是为个性化财务建议开发模型。它降低了隐私保护机器学习的入门门槛,使强大、负责任的AI更易获得。

我们很兴奋与研究社区分享JAX-Privacy。这次发布是多年专注努力的结果,代表了对隐私保护机器学习领域的重大贡献。我们希望通过提供这些工具,能够推动新一波造福所有人的研究和创新。

我们将继续支持和开发该库,整合新的研究进展并响应社区的需求。我们期待看到您使用JAX-Privacy构建的成果。查看GitHub上的存储库或PIP包,立即开始训练保护隐私的机器学习模型。

JAX-Privacy包括以下贡献者的贡献:Leonard Berrada、Robert Stanforth、Brendan McMahan、Christopher A. Choquette-Choo、Galen Andrew、Mikhail Pravilov、Sahra Ghalebikesabi、Aneesh Pappu、Michael Reneer、Jamie Hayes、Vadym Doroshenko、Keith Rush、Dj Dvijotham、Zachary Charles、Peter Kairouz、Soham De、Samuel L. Smith、Judy Hanwen Shen。

Q&A

Q1:JAX-Privacy 1.0是什么?它解决了什么问题?

A:JAX-Privacy 1.0是谷歌基于JAX构建的差分隐私机器学习工具库。它解决了在大规模机器学习中实现差分隐私算法的挑战,让研究人员和开发者能够在保护个人隐私的前提下,使用大型数据集训练AI模型,特别是在训练和微调基础模型时。

Q2:差分隐私为什么重要?它如何保护隐私?

A:差分隐私是量化和限制隐私泄露的黄金标准。它保证算法的输出几乎相同,无论数据集中是否包含某个特定个体的数据。这意味着即使模型使用了你的数据进行训练,也无法从模型输出中推断出你的个人信息,从而在获得AI模型益处的同时保护个人隐私。

Q3:JAX-Privacy与现有框架相比有什么优势?

A:JAX-Privacy基于JAX的高性能计算能力,支持自动微分、即时编译和多加速器扩展,在可扩展性和灵活性方面超越现有框架。它提供开箱即用的并行处理能力,支持大规模模型的分布式训练,并已经在VaultGemma等世界级差分隐私大语言模型的训练中得到验证。

来源:Google

0赞

好文章,需要你的鼓励

2025

12/31

15:57

分享

点赞

邮件订阅