Pytorch乘法算子的实现是怎样,有几种方式
Admin 2022-08-13 群英技术资讯 769 次浏览
这篇文章主要介绍了Pytorch乘法算子的实现是怎样,有几种方式相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇Pytorch乘法算子的实现是怎样,有几种方式文章都会有所收获,下面我们一起来看看吧。pytorch 用于训练,TensorRT 用于推理是很多 AI 应用开发的标配。大家往往更加熟悉 pytorch 的算子,而不太熟悉 TensorRT 的算子,这里拿比较常用的乘法运算在两种框架下的实现做一个对比,可能会有更加直观一些的认识。
先把 pytorch 中的一些常用的乘法运算进行一个总览:
如上进行了一些具体罗列,可以归纳出,常用的乘法无非两种:矩阵乘 和 点乘,所以下面分这两类进行介绍。
先来看看矩阵乘法的 pytorch 的实现 (以下实现在终端):
>>> import torch >>> # torch.mm >>> a = torch.randn(66, 99) >>> b = torch.randn(99, 88) >>> c = torch.mm(a, b) >>> c.shape torch.size([66, 88]) >>> >>> # torch.bmm >>> a = torch.randn(3, 66, 99) >>> b = torch.randn(3, 99, 77) >>> c = torch.bmm(a, b) >>> c.shape torch.size([3, 66, 77]) >>> >>> # torch.mv >>> a = torch.randn(66, 99) >>> b = torch.randn(99) >>> c = torch.mv(a, b) >>> c.shape torch.size([66]) >>> >>> # torch.matmul >>> a = torch.randn(32, 3, 66, 99) >>> b = torch.randn(32, 3, 99, 55) >>> c = torch.matmul(a, b) >>> c.shape torch.size([32, 3, 66, 55]) >>> >>> # @ >>> d = a @ b >>> d.shape torch.size([32, 3, 66, 55])
来看 TensorRT 的实现,以上乘法都可使用 addMatrixMultiply 方法覆盖,对应 torch.matmul,先来看该方法的定义:
//!
//! \brief Add a MatrixMultiply layer to the network.
//!
//! \param input0 The first input tensor (commonly A).
//! \param op0 The operation to apply to input0.
//! \param input1 The second input tensor (commonly B).
//! \param op1 The operation to apply to input1.
//!
//! \see IMatrixMultiplyLayer
//!
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new matrix multiply layer, or nullptr if it could not be created.
//!
IMatrixMultiplyLayer* addMatrixMultiply(
ITensor& input0, MatrixOperation op0, ITensor& input1, MatrixOperation op1) noexcept
{
return mImpl->addMatrixMultiply(input0, op0, input1, op1);
}
可以看到这个方法有四个传参,对应两个张量和其 operation。来看这个算子在 TensorRT 中怎么添加:
// 构造张量 Tensor0 nvinfer1::IConstantLayer *Constant_layer0 = m_network->addConstant(tensorShape0, value0); // 构造张量 Tensor1 nvinfer1::IConstantLayer *Constant_layer1 = m_network->addConstant(tensorShape1, value1); // 添加矩阵乘法 nvinfer1::IMatrixMultiplyLayer *Matmul_layer = m_network->addMatrixMultiply(Constant_layer0->getOutput(0), matrix0Type, Constant_layer1->getOutput(0), matrix2Type); // 获取输出 matmulOutput = Matmul_layer->getOputput(0);
再来看看点乘的 pytorch 的实现 (以下实现在终端):
>>> import torch >>> # torch.mul >>> a = torch.randn(66, 99) >>> b = torch.randn(66, 99) >>> c = torch.mul(a, b) >>> c.shape torch.size([66, 99]) >>> d = 0.125 >>> e = torch.mul(a, d) >>> e.shape torch.size([66, 99]) >>> # * >>> f = a * b >>> f.shape torch.size([66, 99])
来看 TensorRT 的实现,以上乘法都可使用 addScale 方法覆盖,这在图像预处理中十分常用,先来看该方法的定义:
//!
//! \brief Add a Scale layer to the network.
//!
//! \param input The input tensor to the layer.
//! This tensor is required to have a minimum of 3 dimensions in implicit batch mode
//! and a minimum of 4 dimensions in explicit batch mode.
//! \param mode The scaling mode.
//! \param shift The shift value.
//! \param scale The scale value.
//! \param power The power value.
//!
//! If the weights are available, then the size of weights are dependent on the ScaleMode.
//! For ::kUNIFORM, the number of weights equals 1.
//! For ::kCHANNEL, the number of weights equals the channel dimension.
//! For ::kELEMENTWISE, the number of weights equals the product of the last three dimensions of the input.
//!
//! \see addScaleNd
//! \see IScaleLayer
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new Scale layer, or nullptr if it could not be created.
//!
IScaleLayer* addScale(ITensor& input, ScaleMode mode, Weights shift, Weights scale, Weights power) noexcept
{
return mImpl->addScale(input, mode, shift, scale, power);
}
可以看到有三个模式:
再来看这个算子在 TensorRT 中怎么添加:
// 构造张量 input
nvinfer1::IConstantLayer *Constant_layer = m_network->addConstant(tensorShape, value);
// scalemode选择,kUNIFORM、kCHANNEL、kELEMENTWISE
scalemode = kUNIFORM;
// 构建 Weights 类型的 shift、scale、power,其中 volume 为元素数量
nvinfer1::Weights scaleShift{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scaleScale{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scalePower{nvinfer1::DataType::kFLOAT, nullptr, volume };
// !! 注意这里还需要对 shift、scale、power 的 values 进行赋值,若只是乘法只需要对 scale 进行赋值就行
// 添加张量乘法
nvinfer1::IScaleLayer *Scale_layer = m_network->addScale(Constant_layer->getOutput(0), scalemode, scaleShift, scaleScale, scalePower);
// 获取输出
scaleOutput = Scale_layer->getOputput(0);
有一点你可能会比较疑惑,既然是点乘,那么输入只需要两个张量就可以了,为啥这里有 input、shift、scale、power 四个张量这么多呢。解释一下,input 不用说,就是输入张量,而 shift 表示加法参数、scale 表示乘法参数、power 表示指数参数,说到这里,你应该能发现,这个函数除了我们上面讲的点乘外还有其他更加丰富的运算功能。
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
由于字符串数据几乎无处不在,因此掌握有关字符串的交易工具非常重要。幸运的是,Python 使字符串操作变得非常简单,尤其是与其他语言甚至旧版本的 Python 相比时。本文将为大家详细介绍Python中字符串的拆分与连接,需要的可以参考一下
这篇文章主要为大家介绍了PyTorch搭建双向LSTM实现时间序列负荷预测,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
大家好,本篇文章主要讲的是python绘制超炫酷动态Julia集示例,感兴趣的痛学赶快来看一看吧,对你有帮助的话记得收藏一下,方便下次浏览
这篇文章主要介绍了python如何实现数组元素两两相加,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
这篇文章主要为大家介绍了python密码学换位密码及换位解密转置加密教程,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008