人工智能大数据,工作效率生产力
Ctrl + D 收藏本站,更多好用AI工具
当前位置:首页 » AI资讯

AI框架的第三条路:面向对象+函数式融合编程

2024-05-15 80

写在开头——这是一篇硬核的“软文”。

AI框架发展至今,从功能上已经逐渐趋同,无外乎自动微分、Layer封装、并行训练、推理部署等几乎可称之为显学的核心功能,以及在基础框架之上生长的各类二次开发领域框架(套件)。关注AI框架演进的大概有两拨人,一波专攻算法,要的是能够快速实现算法的自由灵活,另一波专攻系统,为的是大规模分布式并行及模型部署落地。作为一个library,从功能、性能视角来对其探讨分析的不胜枚举,而这篇文章尝试以一个新的视角去看待AI框架的演进——编程范式(Programming paradigm)。

优雅易用的面向对象编程

Pytorch自诞生以来,凭借着易用性成为科研人员的首选,隐有业界标准的态势。Pytorch选择面向对象编程(Object Oriented Programming)的编程范式,不论是Tensor对象还是封装后的Module对象,分别能够简化反向传播和神经网络构造的复杂度。同时OOP也是所有开发人员最易理解和使用的编程范式,

在一般的编程场景中,代码(code)和数据(data)是两个核心构成部分[1]。面向对象编程是针对特定对象(Object)来设计数据结构,定义类(Class)。类通常由以下两部分构成,分别对应了code和data:

  • 方法(Methods)
  • 属性(Attributes)

对于同一个Class实例化(instantiation)后得到的不同对象而言,methods和attributes相同,不同的是attributions的值。不同的属性值决定了对象的内部状态,因此OOP能够很好的进行状态管理。

使用Python简单构造一个类:

class Sample: #class declaration def __init__(self, name): # class constructor (code) self.name = name # attribute (data) def set_name(self, name): # method declaration (code) self.name = name # method implementation (code)

再来看Pytorch的面向对象编程:

class Linear(nn.Module): def __init__(self, in_features, out_features, has_bias): # class constructor (code) super().__init__() self.weight = nn.Parameter([out_features, in_features]) # layer weight (data) self.bias = nn.Parameter([out_features]) # layer weight (data) def forward(self, inputs): # method declaration (code) output = torch.matmul(inputs, self.weight.transpose(0, 1)) # tensor transformation (code) output = output + self.bias # tensor transformation (code) return output

对于构造神经网络来说,首要的组件就是网络层(Layer),一个神经网络层包含以下部分:

  • Tensor操作 (Operation)
  • 权重 (Weights)

此二者恰好与类的methods和attributes一一对应,同时权重本身就是神经网络层的内部状态,因此使用类来构造Layer天然符合其定义。此外,我们在编程时希望使用神经网络层进行堆叠,构造深度神经网络,使用OOP编程可以很容易的通过Layer对象组合构造新的Layer类。

可以说OOP是完美契合神经网络建模的编程范式,正因如此,无论是新手还是用老框架的研究者们使用Pytorch时感受到的不可名状的舒适感,也就可以想见了。

科学计算和函数式编程

科学计算和函数式编程并非强相关,只因Jax这个定位于Numerical Computing,使用于Scientific Computing的框架选择纯函数式的编程范式,因此放到一起介绍。

在函数式编程中,函数被视为一等公民,这意味着它们可以绑定到名称(包括本地标识符),作为参数传递,并从其他函数返回,就像任何其他数据类型一样。这允许以声明性和可组合的风格编写程序,其中小功能以模块化方式组合。函数式编程有时被视为纯函数式编程的同义词,是将所有函数视为确定性数学函数或纯函数的函数式编程的一个子集。当使用一些给定参数调用纯函数时,它将始终返回相同的结果,并且不受任何可变状态或其他副作用的影响。[2]

函数式编程有两个核心特点,使其十分符合科学计算的需要:

  1. 编程函数语义与数学函数语义完全对等。
  2. 确定性,给定相同输入必然返回相同输出。无副作用。

由于确定性这一特点,纯函数式编程的拥趸者们笃信通过限制副作用,程序可以有更少的错误,更容易调试和测试,更适合形式验证。但是纯函数式编程对开发者要求极高,且思维模式转变的成本大于减少出错带来的收益,使得函数式编程一度成为学术界的玩具,只流行于编程语言理论领域的研究中。

回过头来看Jax,其名称即是定位:Autograd and XLA。Jax其实并非一个深度学习框架,因为它要更低一层,只做数值计算加速和自动微分。也正是由于有自动微分这一反向传播必需的基础能力,因此很多时候也将其划到了深度学习框架的范畴内。

从Jax的定位来看,其选择纯函数式编程范式的原因也就明了了——不需要考虑一个通用编程语言那么多使用场景的兼顾,同时数值计算这一场景本身就应当满足数学语义。再看一下AssemblyAI的blog里提到的Jax函数式编程三点限制[3]

  • It cannot change the state of the program by accessing or assigning variables outside its scope
  • It cannot have an I/O stream – so no printing, asking for input, or accessing the time
  • It cannot have a mutable function as an argument (which a concurrent process could modify)

由于移除了所有的状态控制,因此函数内部是完全不可修改的。此外,Jax也无法使用像Pytorch一样的inplace操作。从上述三点来看,纯函数式依旧不可避免地增加了学习曲线的陡峭程度。在使用Jax时,需要摒弃OOP的思维来思考如何实现自己想要的功能,并习惯函数变换(输入函数,返回新函数)。

一个简单的Jax样例:

grad_tanh = grad(jnp.tanh) print(grad_tanh(2.0)) # 0.070650816 print(grad(grad(jnp.tanh))(2.0)) print(grad(grad(grad(jnp.tanh)))(2.0)) # -0.13621868 # 0.25265405

但反过来考虑,做科学计算乃至数值计算的人,未见得学过编程,没有OOP思维的先入为主,上述案例的高阶微分操作反而更符合一直以来学习数学所养成的习惯。只是从AI模型构造的角度,对一众程序员和算法工程师而言相对较难。

MindSpore选择的第三条路

通过消灭可变状态来达到无副作用的目的虽然完全符合数学语义上的函数,但是理解和编程难度陡升。正因如此,纯函数式编程的语言几乎没有能够成为主流开发语言,而现代编程语言几乎不约而同的选择了接纳函数式编程特性,但不全盘all in函数式的做法。

他山之石,可以攻玉。MindSpore选择和编程语言们一样的发展路径,将函数式编程和面向对象编程融合,兼顾用户习惯和新赛道需要,提供易用性最好,编程体验最佳的融合编程范式。

纯函数式/面向对象的问题

有了前文的介绍,可以很容易推断出纯函数式/面向对象带来的问题。对于Jax或这样纯函数式编程范式的框架而言,虽然框架支持自动微分,但是由于缺少了状态管理,对于习惯了Pytorch这样的面向对象编程习惯而言很难适应:

  1. 没有Layer概念,神经网络由多个函数变化构成。
  2. 权重和数据平等看待,都作为函数输入,反而难对其进行区分。
  3. 函数具有确定性,但需要依赖外部变量/对外部变量造成副作用的场景非常常见。

为了解决1和2,基于Jax之上封装了像Haiku、Flax这样的套件,虽然能够模拟一部分OOP的特性,但终究逃不过第三点的制约。如同纯函数式编程语言的困境一样,Jax没有快速流行并替代Tensorflow的原因大概也是相同的——程序员接受度太低。

而纯面向对象的Pytorch作为一个AI框架本身是很难挑出什么毛病的,但由于Scientific Computing的火热(也许是AI泡沫着急转移),总要在新的赛道上分一杯羹,这时候为了支持Jax这样的函数式表达就有些捉襟见肘:

  1. torch.autograd和函数式的grad无法统一。
  2. 以Tensor为中心的反向传播接口和函数式接口完全割裂。

因此,Pytorch选择了和Jax相同的思路,基于框架进行二次封装,做了functorch,从接口形态上完全模拟Jax的·函数变换。但是由于本身Module对象无法作为函数使用,因此有些多此一举地做了make_functional接口,先将Module中的权重取出并构造一个FunctionalModule,然后再原封不动的塞回去。 原本以Pythonic著称的Pytorch也为了支持函数式而做了这样的妥协,显得不那么优雅。

既然纯函数式/面向对象都有其弊病,且来看看融合编程范式会是什么样。

MindSpore融合编程的优势

不论是Pytorch还是Jax,其发展思路都是保住基本盘而后用套件完成跨领域生态构建的策略。而MindSpore不同,自立项起,整体设计上就是同时面向AI和科学计算的(我愿称之为真正的全场景)。虽然老版本的GradOperation使用方式多少有些不便,但是框架架构上天然支持并一直采用半函数式做反向传播的功能,使得融合编程范式的切换几乎水到渠成。

其实道理也很简单,纯函数式会让学习曲线陡增,易用性变差;OOP构造神经网络的编程范式深入人心,Pytorch生态独霸学术圈;科学计算新赛道崛起,势不可挡。我们为的是让使用框架的人更便利而非增加困难,所以要做到以下几点:

  1. OOP方式构造神经网络。
  2. 一套自动微分机制同时用于科学计算和AI。
  3. 带副作用的半函数式,方便状态管理。
  4. 如果愿意,一样可以写纯函数式。

这几点做到了,大可以说MindSpore是一个Pytorch+Jax and beyond的框架了。这里简单做一个表格来对比一下三家的方案,孰优孰劣,一看便知。

Pytorch+functorch Jax+haiku/flax MindSpore
方案 AI-Centric
框架+套件
Numerical-Centric
框架+套件
DualCore(AI-Numerical)
框架
优点 对AI编程无影响 对数值计算编程无影响 AI和数值计算原生融合,共用同一套编程范式。
不足 1. 函数微分和基于Tensor的反向传播逻辑割裂。
2. 原有Module需要同API转换为function。
基于Jax纯函数式编程的特点,无法对layer对象进行有效管理。 /

杂交优势一览

Talk is cheap, show me the code.

前面分析了一通,直接上代码。

Pytorch:

# Class definition class Net(nn.Module): def __init__(self): ...... def forward(self, inputs): ...... # Object instantiation net = Net() # network loss_fn = nn.CrossEntropyLoss() # loss function optimizer = optim.Adam(net.parameters(), lr) # optimizer for i in range(epochs): for inputs, targets in dataset(): optimizer.zero_grad() logits = net(inputs) loss = loss_fn(logits, targets) loss.backward() # back propagation optimizer.step() # update gradient

MindSpore(≥1.8.1):

# Class definition class Net(nn.Cell): def __init__(self): ...... def construct(self, inputs): ...... # Object instantiation net = Net() # network loss_fn = nn.CrossEntropyLoss() # loss function optimizer = nn.Adam(net.trainable_params(), lr) # optimizer # define forward function def forword_fn(inputs, targets): logits = net(inputs) loss = loss_fn(logits, targets) return loss, logits # get grad function grad_fn = value_and_grad(forward_fn, None, optim.parameters, has_aux=True) # define train step function def train_step(inputs, targets): (loss, logits), grads = grad_fn(inputs, targets) # get values and gradients optimizer(grads) # update gradient return loss, logits for i in range(epochs): for inputs, targets in dataset(): loss = train_step(inputs, targets)

Jax(Flax+Optax):

# Class definition class Net(nn.Module): def __init__(self): ...... def __call__(self, inputs): ...... # define train step function def train_step(state, batch): # define forward function def forward_fn(params): logits = Net().apply({'params': params}, batch['inputs']) loss = optax.softmax_cross_entropy(logits, batch['targets']) return loss, logits # get grad function grad_fn = jax.value_and_grad(forward_fn, has_aux=True) # get values and gradients (loss, logits), grads = grad_fn(state.params) # update gradients state = state.apply_gradient(grads) return state, loss # create state state = create_train_state(init_rng, learning_rate, adam) for i in range(epochs): for batch in dataset: state, loss = train_step(state, batch) 

三个样例看下来,最直观的感受应该是MindSpore的代码比Pytorch略复杂,比Jax要简略很多,下面再逐一分析每个部分:

  1. 网络构造,满足面向对象编程习惯,MindSpore提供与Pytorch完全一致的Layer构造方式,因此该部分两者几乎完全相同。
  2. 前向计算和反向传播,MindSpore使用函数式,将前向计算构造成function,然后通过函数变换,获得grad function,最后通过执行grad function获得权重对应的梯度。
  3. 权重更新,这里三者差异最大,其中Pytorch通过step方法内置了获取梯度和更新权重的过程,仅需要一次方法调用即可;MindSpore则将Optimizer同样视为和Layer同样的组件,管理整个网络的权重,通过对象调用传入梯度进行更新;而Jax则最为复杂,将权重单独使用一个State对象进行管理,需要提前创建State,然后不断作为函数的输入输出,进行更新。

进一步分析其中每个部分的差异:

Pytorch Jax(Flax+Optax) MindSpore
Layer 对象,权重为属性 模拟对象,无权重 对象管理,权重为属性
Loss 对象,同Layer 函数 对象,同Layer
Optimizer 对象,管理状态 对象 对象,同Layer,
状态也为权重
State / 对象,管理权重 /
正向计算 直接调用对象 构造函数 构造函数
反向传播 Tensor类方法 反向函数 反向函数
权重更新 Optimizer类方法 State类方法 Optimizer直接调用

两棵决策树

前面也分析介绍了很多,最后也仿照AssemblyAI的blog做两个决策树,跟着问题回答Yes or No,我想大家一定能够选到最适合自己的框架。

AI框架的第三条路:面向对象+函数式融合编程插图
AI框架的第三条路:面向对象+函数式融合编程插图1

参考

  1. ^为什么Pytorch选择面向对象编程 https://phucnsp.github.io/blog/self-taught/2020/03/22/self-taught-pytorch-part1-tensor.html#Object-Oriented-Programming-and-Why-Pytorch-select-it
  2. ^https://vibaike.com/124588/
  3. ^2022年了,为什么你应该(不应该)用Jax https://www.assemblyai.com/blog/why-you-should-or-shouldnt-be-using-jax-in-2022

原文链接:https://zhuanlan.zhihu.com/p/554357198

相关推荐

阅读榜

hellenandjeckett@outlook.com

加入QQ群:849112589

回顶部