课程 ID: 19254
描述:
话题概述:
在大规模搜索、广告、推荐(SAR)场景下,传统的分布式训练框架面临双重挑战:
1、扩展性瓶颈: 随着特征维度的爆发,传统的TFRA(TensorFlow Recommenders Addons)和Parameter Server GPU 架构在处理大规模稀疏 Embedding 时,常受限于显存容量与总线带宽,导致性能无法随算力线性扩展。并且PS架构由于底层抽象复杂,算法工程师在构建复杂模型时面临极高的开发门槛与调试成本。
2、范式演进: 生成式推荐(Generative Recommendation)的兴起,要求架构必须同时兼容传统 PS 模式与现代大模型的 3D 并行/FSDP 等技术,原有框架的“补丁式”改进已无法满足需求。由于搜推模型的结构变更频繁,开发者迫切需要一种既能享受 XLA 编译加速,又能灵活处理异构分布式计算/存储的下一代架构。
3、JAX 的生态错位: 虽然 JAX 凭借 XLA 的极致优化在 LLM 领域大放异彩,但其原生的 SPMD(单程序多数据) 模式在面对 SAR 场景中高度动态、非对称的 MPMD(多程序多数据) 需求(如巨大的 Embedding 分片与异步更新、超大规模稀疏专家)时,存在天然的架构屏障。
演讲题纲:
话题亮点:
本议题核心探讨如何利用 JAX 重构搜广推模型底层,并打破其原生的并行范式限制:
1、从 SPMD 到 MPMD 的跨越: 针对 JAX 难以处理大规模非对称任务的问题,引入高性能 UCX RPC 框架作为底层通信原语。通过构建 P2P 通信能力,在 JAX 体系内实现了高性能的异步pipeline分布式 Embedding 访问,打破了单一的并行模式。
2、有机结合编译优化与动态行为: 面对搜推任务、强化学习任务中的大量动态行为、Tensor动态shape,在框架内实现了尽可能的友好接入,设计只在静态shape处编译优化。
3、高性能用户态通信栈: 采用基于 C++ 26(std::execution) 的 UCX 封装,在不破坏 JAX 执行流的前提下,实现了低延迟、零拷贝的参数传输,为 MPMD 模式提供了坚实的通信基座。
4、生成式推荐的架构融合: 统一了传统稀疏模型与大模型训练技术(3D Parallelism / FSDP),支持在同一套 JAX 架构下完成生成式推荐模型的快速迭代。