CANN/asc-devkit反量化API

张开发
2026/5/9 13:24:17 15 分钟阅读

分享文章

CANN/asc-devkit反量化API
AscendDequant【免费下载链接】asc-devkit本项目是CANN 推出的昇腾AI处理器专用的算子程序开发语言原生支持C和C标准规范主要由类库和语言扩展层构成提供多层级API满足多维场景算子开发诉求。项目地址: https://gitcode.com/cann/asc-devkit产品支持情况产品是否支持Ascend 950PR/Ascend 950DT√Atlas A3 训练系列产品 / Atlas A3 推理系列产品√Atlas A2 训练系列产品 / Atlas A2 推理系列产品√Kirin X90√Kirin 9030√功能说明按元素做反量化计算比如将int32_t数据类型反量化为half/float等数据类型。本接口最多支持输入为二维数据不支持更高维度的输入。假设输入srcTensor的shape为**m, n每行数据即n个输入数据所占字节数要求32字节对齐**每行中进行反量化的元素个数为calCount反量化系数deqScale可以为标量或者向量为向量的情况下calCount deqScale的元素个数只有前CalCount个反量化系数生效输出dstTensor的shape为**m, n_dst n * sizeof(dstT)不满足32字节对齐时需要向上补齐为32字节**n_dst为向上补齐后的列数。下面通过两个具体的示例来解释参数的配置和计算逻辑下文中DequantParams类型为存储shape信息的结构体{m, n, calCount}如下图示例中srcTensor的数据类型为int32_tm 4n 8calCount 4表明srcTensor中每行进行反量化的元素个数为4deqScale中的前4个数生效后12个数不参与反量化计算dstTensor的数据类型为bfloat16_tm 4n_dst 16 (16 * sizeof(bfloat16_t) % 32 0)。计算逻辑是srcTensor的每n个数为一行对于每行中的前calCount个元素该行srcTensor的第i个元素与deqScale的第i个元素进行相乘写入dstTensor对应行的第i个元素dstTensor对应行的第calCount 1个元素~第n_dst个元素均为不确定的值。如下示例中srcTensor的数据类型为int32_tm 4n 8 calCount 4表明srcTensor中每行进行反量化的元素个数为4dstTensor的数据类型为floatm 4n_dst 8 (8 * sizeof(float) % 32 0)。对于srcTensor每行中的前4个元素都和标量deqScale相乘并写入dstTensor中每行的对应位置。当用户将模板参数中的mode配置为DEQUANT_WITH_SINGLE_ROW时针对DequantParams {m, n, calCount} 若同时满足以下3个条件m 1calCount为 32 / sizeof(dstT)的倍数n % calCount 0此时 {1, n, calCount}会被视作为** {n / calCount, calCount, calCount}** 进行反量化的计算。具体效果可看下图所示传入的DequantParams为 {1, 16, 8}。因为dstT为float所以calCount满足为8的倍数在DEQUANT_WITH_SINGLE_ROW模式下会将{1, 2 * 8, 8}转换为 {2, 8, 8}进行计算。PER_TOKEN反量化srcTensor的每组tokentoken为n方向共有m组token中的元素共享一组deqscale参数srcTensor为[m, n]时deqscale为[m, 1]。PER_GROUP反量化这里定义group的计算方向为k方向srcTensor在k方向上每groupSize个元素共享一组deqscale参数。srcTensor为[m, n]时如果kDim0表示k是m方向deqscale为[(m groupSize - 1) / groupSize, n]如果kDim1表示k是n方向deqscale的shape为[m(n groupSize - 1) / groupSize]。kDim0kDim1实现原理以数据类型int32_tshape为[m, n]的输入srcTensor数据类型scaleTshape为[n]的输入deqScale和数据类型dstTshape为[m, n]的输出dstTensor为例描述AscendDequant高阶API内部算法框图如下图所示。图 1AscendDequant内部算法框图![](https://raw.gitcode.com/cann/asc-devkit/raw/0f09e05bd80fac60c9606596c6011b45973c6a0b/docs/api/context/figures/AscendDequant内部算法框图.png AscendDequant内部算法框图?utm_sourcegitcode_repo_files)计算过程分为如下几步均在Vector上进行精度转换将srcTensor和deqScale都转换成FP32精度的tensor分别得到srcFP32和deqScaleFP32Mul计算srcFP32一共有m行每行长度为n通过m次循环将srcFP32的每行与deqScaleFP32相乘通过mask控制仅对前dequantParams.calcount个数进行mul计算图中index的取值范围为 [0, m)对应srcFP32的每一行计算所得结果为mulResshape为[m, n]结果数据精度转换mulRes从FP32转换成dstT类型的tensor所得结果为dstTensorshape为[m, n]。PER_TOKEN/PER_GROUP场景下输入srcTensor数据类型是int32_t/float此时内部算法框图如下所示。图 2AscendDequant PER_TOKEN/PER_GROUP内部算法框图![](https://raw.gitcode.com/cann/asc-devkit/raw/0f09e05bd80fac60c9606596c6011b45973c6a0b/docs/api/context/figures/AscendDequant-PER_TOKEN-PER_GROUP内部算法框图.png AscendDequant-PER_TOKEN-PER_GROUP内部算法框图?utm_sourcegitcode_repo_files)PER_TOKEN/PER_GROUP场景的计算逻辑如下读取数据连续读取输入srcTensor根据不同的场景对输入deqscale采用不同的读取方式例如PER_TOKEN场景做Broadcast处理PER_GROUP场景做Gather处理精度转换根据不同输入的数据类型组合对srcTensor/deqscale进行相应的数据类型转换计算对类型转换后的srcTensor和deqscale数据做乘法精度转换将上述计算得到的结果转换成dstT类型得到最终输出。函数原型反量化参数deqScale为矢量通过sharedTmpBuffer入参传入临时空间template typename dstT, typename scaleT, DeQuantMode mode DeQuantMode::DEQUANT_WITH_SINGLE_ROW __aicore__ inline void AscendDequant(const LocalTensordstT dstTensor, const LocalTensorint32_t srcTensor, const LocalTensorscaleT deqScale, const LocalTensoruint8_t sharedTmpBuffer, DequantParams params)接口框架申请临时空间template typename dstT, typename scaleT, DeQuantMode mode DeQuantMode::DEQUANT_WITH_SINGLE_ROW __aicore__ inline void AscendDequant(const LocalTensordstT dstTensor, const LocalTensorint32_t srcTensor, const LocalTensorscaleT deqScale, DequantParams params)PER_TOKEN/PER_GROUP量化通过sharedTmpBuffer入参传入临时空间template typename dstT, typename srcT, typename scaleT, const AscendDeQuantConfig config, const AscendDeQuantPolicy policy __aicore__ inline void AscendDequant(const LocalTensordstT dstTensor, const LocalTensorsrcT srcTensor, const LocalTensorscaleT scaleTensor, const LocalTensorscaleT offsetTensor, const LocalTensoruint8_t sharedTmpBuffer, const AscendDeQuantParam para)接口框架申请临时空间template typename dstT, typename srcT, typename scaleT, const AscendDeQuantConfig config, const AscendDeQuantPolicy policy __aicore__ inline void AscendDequant(const LocalTensordstT dstTensor, const LocalTensorsrcT srcTensor, const LocalTensorscaleT scaleTensor, const LocalTensorscaleT offsetTensor, const AscendDeQuantParam para)反量化参数deqScale为标量通过sharedTmpBuffer入参传入临时空间template typename dstT, typename scaleT, DeQuantMode mode DeQuantMode::DEQUANT_WITH_SINGLE_ROW __aicore__ inline void AscendDequant(const LocalTensordstT dstTensor, const LocalTensorint32_t srcTensor, const scaleT deqScale, const LocalTensoruint8_t sharedTmpBuffer, DequantParams params)接口框架申请临时空间template typename dstT, typename scaleT, DeQuantMode mode DeQuantMode::DEQUANT_WITH_SINGLE_ROW __aicore__ inline void AscendDequant(const LocalTensordstT dstTensor, const LocalTensorint32_t srcTensor, const scaleT deqScale, DequantParams params)由于该接口的内部实现中涉及复杂的数学计算需要额外的临时空间来存储计算过程中的中间变量。临时空间支持接口框架申请和开发者通过sharedTmpBuffer入参传入两种方式。接口框架申请临时空间开发者无需申请但是需要预留临时空间的大小。通过sharedTmpBuffer入参传入使用该tensor作为临时空间进行处理接口框架不再申请。该方式开发者可以自行管理sharedTmpBuffer内存空间并在接口调用完成后复用该部分内存内存不会反复申请释放灵活性较高内存利用率也较高。接口框架申请的方式开发者需要预留临时空间通过sharedTmpBuffer传入的情况开发者需要为sharedTmpBuffer申请空间。临时空间大小BufferSize的获取方式如下通过GetAscendDequantMaxMinTmpSize中提供的GetAscendDequantMaxMinTmpSize接口获取需要预留空间的范围大小。以下接口不推荐使用新开发内容不要使用如下接口template typename dstT, typename scaleT, DeQuantMode mode DeQuantMode::DEQUANT_WITH_SINGLE_ROW __aicore__ inline void AscendDequant(const LocalTensordstT dstTensor, const LocalTensorint32_t srcTensor, const LocalTensorscaleT deqScale, const LocalTensoruint8_t sharedTmpBuffer, const uint32_t calCount)template typename dstT, typename scaleT, DeQuantMode mode DeQuantMode::DEQUANT_WITH_SINGLE_ROW __aicore__ inline void AscendDequant(const LocalTensordstT dstTensor, const LocalTensorint32_t srcTensor, const LocalTensorscaleT deqScale, const LocalTensoruint8_t sharedTmpBuffer)template typename dstT, typename scaleT, DeQuantMode mode DeQuantMode::DEQUANT_WITH_SINGLE_ROW __aicore__ inline void AscendDequant(const LocalTensordstT dstTensor, const LocalTensorint32_t srcTensor, const LocalTensorscaleT deqScale, const uint32_t calCount)template typename dstT, typename scaleT, DeQuantMode mode DeQuantMode::DEQUANT_WITH_SINGLE_ROW __aicore__ inline void AscendDequant(const LocalTensordstT dstTensor, const LocalTensorint32_t srcTensor, const LocalTensorscaleT deqScale)参数说明表 1模板参数说明参数名描述dstT目的操作数的数据类型。scaleTdeqScale的数据类型。mode决定当DequantParams为{1, n, calCount}时的计算逻辑传入enum DeQuantMode支持以下 2 种配置DEQUANT_WITH_SINGLE_ROW当DequantParams {m, n, calCount} 同时满足以下条件1、m 12、calCount为 32 / sizeof(dstT)的倍数3、n % calCount 0时即 {1, n, calCount} 会当作 {n / calCount, calCount, calCount} 进行计算。DEQUANT_WITH_MULTI_ROW即使满足上述所有条件{1, n, calCount} 依然只会当作 {1, n, calCount} 进行计算 即总共n个数前calCount个数进行反量化的计算。表 2PER_TOKEN/PER_GROUP场景模板参数说明参数名描述srcT源操作数的数据类型。config量化接口配置参数AscendDeQuantConfig类型具体定义如下。struct AscendDeQuantConfig { bool hasOffset; int32_t kDim 1; }hasOffset量化参数offset是否参与计算。True表示offset参数参与计算。False表示offset参数不参与计算。kDimgroup的计算方向即k方向。仅在PER_GROUP场景有效支持的取值如下。0k轴是第0轴即m方向为group的计算方向1k轴是第1轴即n方向为group的计算方向。policy量化策略配置参数枚举类型可取值如下enum class AscendDeQuantPolicy : int32_t { PER_TOKEN, // 配置为PER_TOKEN模式 PER_GROUP, // 配置为PER_GROUP模式 PER_CHANNEL_PER_GROUP, // 预留参数暂不支持 PER_TOEKN_PER_GROUP // 预留参数暂不支持 }表 3接口参数说明参数名输入/输出描述dstTensor输出目的操作数。类型为LocalTensor支持的TPosition为VECIN/VECCALC/VECOUT。Ascend 950PR/Ascend 950DT支持的数据类型为half、bfloat16_t、float。Atlas A3 训练系列产品 / Atlas A3 推理系列产品支持的数据类型为half、bfloat16_t、float。Atlas A2 训练系列产品 / Atlas A2 推理系列产品支持的数据类型为half、bfloat16_t、float。Kirin X90支持的数据类型为half、float。Kirin 9030支持的数据类型为half、float。dstTensor的行数和srcTensor的行数保持一致。n * sizeof(dstT)不满足32字节对齐时需要向上补齐为32字节n_dst为向上补齐后的列数。如srcTensor数据类型为int32_tshape为 (4, 8)dstTensor为bfloat16_t则n_dst应从8补齐为16dstTensor shape为(4, 16)。补齐的计算过程为n_dst (8 * sizeof(bfloat16_t) 32 - 1) / 32 * 32 / sizeof(bfloat16_t)。srcTensor输入源操作数。类型为LocalTensor支持的TPosition为VECIN/VECCALC/VECOUT。Ascend 950PR/Ascend 950DT支持的数据类型为int32_t。Atlas A3 训练系列产品 / Atlas A3 推理系列产品支持的数据类型为int32_t。Atlas A2 训练系列产品 / Atlas A2 推理系列产品支持的数据类型为int32_t。Kirin X90支持的数据类型为int32_t。Kirin 9030支持的数据类型为int32_t。shape为 [m, n]n个输入数据所占字节数要求32字节对齐。deqScale输入源操作数。类型为标量或者LocalTensor。类型为LocalTensor时支持的TPosition为VECIN/VECCALC/VECOUT。Ascend 950PR/Ascend 950DT当deqScale为矢量时支持的数据类型为uint64_t、float、bfloat16_t当deqScale为标量时支持的数据类型为bfloat16_t、float。Atlas A3 训练系列产品 / Atlas A3 推理系列产品当deqScale为矢量时支持的数据类型为uint64_t、float、bfloat16_t当deqScale为标量时支持的数据类型为bfloat16_t、float。Atlas A2 训练系列产品 / Atlas A2 推理系列产品当deqScale为矢量时支持的数据类型为uint64_t、float、bfloat16_t当deqScale为标量时支持的数据类型为bfloat16_t、float。dstTensor、srcTensor、deqScale支持的数据类型组合请参考表5和表6。Kirin X90当deqScale为矢量时支持的数据类型为uint64_t、float当deqScale为标量时支持的数据类型为float。Kirin 9030当deqScale为矢量时支持的数据类型为uint64_t、float当deqScale为标量时支持的数据类型为float。sharedTmpBuffer输入临时缓存。类型为LocalTensor支持的TPosition为VECIN/VECCALC/VECOUT。临时空间大小BufferSize的获取方式请参考GetAscendDequantMaxMinTmpSize。Ascend 950PR/Ascend 950DT支持的数据类型为uint8_t。Atlas A3 训练系列产品 / Atlas A3 推理系列产品支持的数据类型为uint8_t。Atlas A2 训练系列产品 / Atlas A2 推理系列产品支持的数据类型为uint8_t。Kirin X90支持的数据类型为uint8_t。Kirin 9030支持的数据类型为uint8_t。params输入srcTensor的shape信息。DequantParams类型具体定义如下struct DequantParams { uint32_t m; // srcTensor的行数 uint32_t n; // srcTensor的列数 uint32_t calCount; // 针对srcTensor每一行前calCount个数为有效数据与deqScale的前calCount个数或者deqScale标量进行乘法计算 };DequantParams.n * sizeof(T)必须是32字节的整数倍T为srcTensor中元素的数据类型。因为是每n个数中的前calCount个数进行乘法运算因此DequantParams.n和calCount需要满足以下关系1 DequantParams.calCount DequantParams.n。deqScale为矢量时DequantParams.calCount deqScale的元素个数。表 4PER_TOKEN/PER_GROUP场景接口参数说明参数名输入/输出描述dstTensor输出目的操作数。类型为LocalTensor支持的TPosition为VECIN/VECCALC/VECOUT。Ascend 950PR/Ascend 950DT支持的数据类型为half、bfloat16_t、float。srcTensor输入源操作数。类型为LocalTensor支持的TPosition为VECIN/VECCALC/VECOUT。Ascend 950PR/Ascend 950DT支持的数据类型为int32_t、float。sharedTmpBuffer输入临时缓存。类型为LocalTensor支持的TPosition为VECIN/VECCALC/VECOUT。临时空间大小BufferSize的获取方式请参考GetAscendQuantMaxMinTmpSize。Ascend 950PR/Ascend 950DT支持的数据类型为uint8_t。scaleTensor输入量化参数scale。类型为LocalTensor支持的TPosition为VECIN/VECCALC/VECOUT。Ascend 950PR/Ascend 950DT支持的数据类型为half、bfloat16_t、float。offsetTensor输入量化参数offset。预留参数当前暂不支持。类型为LocalTensor支持的TPosition为VECIN/VECCALC/VECOUT。Ascend 950PR/Ascend 950DT支持的数据类型和scaleTensor保持一致。para输入反量化接口的参数定义如下struct AscendDeQuantParam { uint32_t m; uint32_t n; uint32_t calCount; uint32_t groupSize 0; }mm方向元素个数。nn方向元素个数。n值对应的数据大小需满足32B对齐的要求即shape最后一维为n的输入输出均需要满足该维度上32B对齐的要求。calCount:参与计算的元素个数。calCount必须是n的整数倍。groupSize PER_GROUP场景有效表示groupSize行/列数据共用一个scale/offset。groupSize的取值必须大于0且是32的整倍数。表 5支持的数据类型组合deqScale为LocalTensordstTensorsrcTensordeqScalehalfint32_tuint64_t注意当deqScale的数据类型是uint64_t时数值低32位是参与计算的数据数据类型是float数值高32位是一些控制参数本接口不使用。floatint32_tfloatfloatint32_tbfloat16_tbfloat16_tint32_tbfloat16_tbfloat16_tint32_tfloat表 6支持的数据类型组合deqScale为标量dstTensorsrcTensordeqScalebfloat16_tint32_tbfloat16_tbfloat16_tint32_tfloatfloatint32_tbfloat16_tfloatint32_tfloat表 7PER_TOKEN/PER_GROUP场景支持的数据类型组合srcDtypescaleDtypedstDtypeint32_thalfhalfbfloat16_tbfloat16_tfloatfloatfloathalffloatbfloat16_tfloathalfhalfbfloat16_tbfloat16_tfloatfloatfloathalffloatbfloat16_t返回值说明无约束说明不支持源操作数与目的操作数地址重叠。操作数地址对齐要求请参见通用地址对齐约束。PER_TOKEN/PER_GROUP场景连续计算方向即n方向的数据量要求32B对齐。调用示例rowLen m; // m 4 colLen n; // n 8 //输入srcLocal的shape为4*8类型为int32_tdeqScaleLocal的shape为8类型为float预留临时空间 AscendC::AscendDequant(dstLocal, srcLocal, deqScaleLocal, {rowLen, colLen, deqScaleLocal.GetSize()});结果示例如下输入数据(srcLocal) int32_t数据类型: [ -8 5 -5 -7 -3 -8 3 6 9 2 -5 0 0 -5 -7 0 -6 0 -2 3 -2 8 5 2 2 2 -4 5 -4 4 -8 3 ] 反量化参数deqScale float数据类型: [ 10.433567 10.765296 -30.694275 -65.47741 8.386527 -89.646194 65.11153 42.213394] 输出数据(dstLocal) float数据类型: [-83.46854 53.82648 153.47137 458.34186 -25.15958 717.16956 195.33458 253.28036 93.9021 21.530592 153.47137 -0. 0. 448.23096 -455.7807 0. -62.601402 0. 61.38855 -196.43222 -16.773054 -717.16956 325.55762 84.42679 20.867134 21.530592 122.7771 -327.38705 -33.54611 -358.58478 -520.8922 126.64018 ]PER_TOKEN/PER_GROUP场景调用示例如下。// 注意m,n需从外部传入 constexpr static bool isReuseSource false; constexpr static AscendDeQuantConfig config {has_offset, -1}; constexpr static AscendDeQuantPolicy policy AscendDeQuantPolicy::PER_TOKEN; // 可修改枚举值以使能PER_GROUP AscendDeQuantParam para; para.m m; para.n n; para.calCount calCount; AscendDequantdstType, srcType, scaleType, config, policy(dstLocal, srcLocal, scaleLocal, offsetLocal, para);【免费下载链接】asc-devkit本项目是CANN 推出的昇腾AI处理器专用的算子程序开发语言原生支持C和C标准规范主要由类库和语言扩展层构成提供多层级API满足多维场景算子开发诉求。项目地址: https://gitcode.com/cann/asc-devkit创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

更多文章