diff --git a/examples/semantic_indexing/README.md b/examples/semantic_indexing/README.md index 5e70cd25d059..81b9b722ce3c 100644 --- a/examples/semantic_indexing/README.md +++ b/examples/semantic_indexing/README.md @@ -232,6 +232,32 @@ python -u -m paddle.distributed.launch --gpus "0" \ 0.9800204038619995 ``` +## 使用Faster Transformer进行快速预测 +与上述原生预测不同,FasterTransformer预测使用了集成Faster Transformer库的Paddle自定义算子,在一定的配置下,可以对TransformerEncoder的预测进行加速。 + +```shell +python -u -m paddle.distributed.launch --gpus "0" faster_predict.py \ + --params_path "batch_neg_v1.0/model_state.pdparams" \ + --output_emb_size 256 \ + --batch_size 32 \ + --max_seq_length 64 \ + --text_pair_file ${your_input_file} + +``` + +执行上述操作后,可以得到与原生预测非常接近的结果,在float32下,且batch_size=32, max_seq_len=64时,两种预测方式下最终余弦相似度的最大绝对误差约为3.93e-6。 + +通过比较,可以看到在不同batch_size, max_seq_len下,使用集成了FasterTransformer的高性能算子可以对Encoder部分的推理进行加速(其余参数都与默认值相同)。在NVIDIA Tesla V100,16GB的机器上,使用单卡预测得到部分性能数据如下,从表中可以看出在更小的batch_size和max_seq_len上,使用FasterTransformer预测更有优势。 + +| batch size | max_seq_len | FT加速算子(单位:s) | Paddle原生(单位:s) | +| ---------- | ----------- | ------------------- | ------------------- | +| 16 | 16 | 22.645333290100098 | 51.55912470817566 | +| 16 | 32 | 27.326106071472168 | 57.17143130302429 | +| 16 | 64 | 33.31318140029907 | 52.44770574569702 | +| 32 | 16 | 12.891342163085938 | 22.621662139892578 | +| 32 | 32 | 17.206310987472534 | 22.18772006034851 | + + ## 模型介绍 简要介绍 In-batch negatives 策略和 HardestNeg 策略思路 diff --git a/examples/semantic_indexing/data.py b/examples/semantic_indexing/data.py index 5c523126b649..9eb9a1d0d499 100644 --- a/examples/semantic_indexing/data.py +++ b/examples/semantic_indexing/data.py @@ -26,7 +26,6 @@ def create_dataloader(dataset, trans_fn=None): if trans_fn: dataset = dataset.map(trans_fn) - shuffle = True if mode == 'train' else False if mode == 'train': batch_sampler = paddle.io.DistributedBatchSampler( @@ -42,7 +41,10 @@ def create_dataloader(dataset, return_list=True) -def 
convert_example(example, tokenizer, max_seq_length=512): +def convert_example(example, + tokenizer, + max_seq_length=512, + pad_to_max_seq_len=False): """ Builds model inputs from a sequence. @@ -65,11 +67,13 @@ def convert_example(example, tokenizer, max_seq_length=512): result = [] for key, text in example.items(): - encoded_inputs = tokenizer(text=text, max_seq_len=max_seq_length) + encoded_inputs = tokenizer( + text=text, + max_seq_len=max_seq_length, + pad_to_max_seq_len=pad_to_max_seq_len) input_ids = encoded_inputs["input_ids"] token_type_ids = encoded_inputs["token_type_ids"] result += [input_ids, token_type_ids] - return result diff --git a/examples/semantic_indexing/faster_predict.py b/examples/semantic_indexing/faster_predict.py new file mode 100644 index 000000000000..e54643af9834 --- /dev/null +++ b/examples/semantic_indexing/faster_predict.py @@ -0,0 +1,222 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os
from functools import partial
import argparse
from pprint import pprint
import numpy as np

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import TransformerEncoder, TransformerEncoderLayer

from paddlenlp.transformers import ErnieTokenizer, ErnieModel
from paddlenlp.data import Pad, Tuple
from paddlenlp.datasets import load_dataset
from paddlenlp.ops import enable_faster_encoder, disable_faster_encoder

from data import read_text_pair, convert_example, create_dataloader


def parse_args():
    """Parse the command-line options of the prediction script.

    Returns:
        argparse.Namespace: The parsed options. ``--text_pair_file`` and
        ``--params_path`` are required; every other option has a default.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text_pair_file",
        type=str,
        required=True,
        help="The full path of input file")
    parser.add_argument(
        "--output_emb_size",
        default=None,
        type=int,
        help="output_embedding_size")
    parser.add_argument(
        "--params_path",
        type=str,
        required=True,
        help="The path to model parameters to be loaded.")
    parser.add_argument(
        "--max_seq_length",
        default=64,
        type=int,
        help="The maximum total input sequence length after tokenization. "
        "Sequences longer than this will be truncated, sequences shorter will be padded."
    )
    parser.add_argument(
        "--dropout", default=0.0, type=float, help="Dropout probability.")
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU/CPU for prediction.")
    parser.add_argument("--seed", default=42, type=int, help="Random seed.")
    parser.add_argument(
        "--pad_to_max_seq_len",
        action="store_true",
        help="Whether to pad to max_seq_len.")

    return parser.parse_args()


class SemanticIndexingPredictor(nn.Layer):
    """Scores (query, title) pairs by the cosine similarity of their embeddings.

    The constructor replaces ``pretrained_model.encoder`` with a freshly built
    ``paddle.nn.TransformerEncoder``, so trained weights have to be restored
    afterwards (this script does so through :meth:`load`).
    """

    def __init__(self,
                 pretrained_model,
                 output_emb_size,
                 n_layer=12,
                 n_head=12,
                 hidden_size=768,
                 dim_feedforward=3072,
                 activation="relu",
                 bos_id=0,
                 dropout=0,
                 max_seq_len=128,
                 is_gelu=False):
        """Build the predictor around ``pretrained_model``.

        Args:
            pretrained_model: Backbone exposing ``embeddings``, ``encoder`` and
                ``pooler`` attributes; its ``encoder`` is overwritten here.
            output_emb_size: Size of the optional projection applied to the
                pooled output. ``None`` or a non-positive value disables it.
            n_layer: Number of layers of the rebuilt encoder.
            n_head: Number of attention heads per encoder layer.
            hidden_size: Model width of the encoder and input width of the
                projection layer.
            dim_feedforward: Width of the feed-forward sub-layer.
            activation: Feed-forward activation of the encoder layers.
            bos_id: Token id treated as padding when the attention mask is
                derived from ``input_ids``.
            dropout: Dropout probability; ``None`` is treated as ``0.0``.
            max_seq_len: Accepted for interface compatibility; not read here.
            is_gelu: Accepted for interface compatibility; not read here.
        """
        super(SemanticIndexingPredictor, self).__init__()
        dropout = dropout if dropout is not None else 0.0
        self.bos_id = bos_id
        self.ptm = pretrained_model
        self.dropout = nn.Dropout(dropout)
        # ``--output_emb_size`` defaults to None. Normalise it to 0 so the
        # ``> 0`` checks here and in get_pooled_embedding never compare None
        # with an int, which raises TypeError on Python 3.
        self.output_emb_size = output_emb_size or 0
        if self.output_emb_size > 0:
            weight_attr = paddle.ParamAttr(
                initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
            # Project from the model width rather than a hard-coded 768 so a
            # non-default ``hidden_size`` keeps the shapes consistent.
            self.emb_reduce_linear = paddle.nn.Linear(
                hidden_size, self.output_emb_size, weight_attr=weight_attr)
        encoder_layer = TransformerEncoderLayer(
            hidden_size,
            n_head,
            dim_feedforward,
            dropout=dropout,
            activation=activation)
        self.ptm.encoder = TransformerEncoder(encoder_layer, n_layer)

    def get_pooled_embedding(self,
                             input_ids,
                             token_type_ids=None,
                             position_ids=None,
                             attention_mask=None):
        """Return the L2-normalised pooled embedding of ``input_ids``.

        ``position_ids`` and ``attention_mask`` are accepted but ignored:
        positions are always recomputed as 0..len-1 along axis 1, and the
        mask is always derived from ``input_ids != self.bos_id``.

        Args:
            input_ids: Token id tensor (assumed ``[batch, seq_len]`` -- TODO
                confirm against the data pipeline).
            token_type_ids: Segment ids forwarded to the embedding layer.
            position_ids: Ignored, see above.
            attention_mask: Ignored, see above.

        Returns:
            The pooled embedding, optionally projected to
            ``output_emb_size``, after dropout and L2 normalisation over the
            last axis.
        """
        # Cast the mask to the encoder's parameter dtype.
        mask_dtype = self.ptm.encoder.layers[0].norm1.bias.dtype
        src_mask = (input_ids != self.bos_id).astype(mask_dtype)
        src_mask = paddle.unsqueeze(src_mask, axis=[1, 2])
        src_mask.stop_gradient = True

        # cumsum(ones) - ones yields 0, 1, 2, ... along the sequence axis.
        ones = paddle.ones_like(input_ids, dtype="int64")
        seq_length = paddle.cumsum(ones, axis=1)
        position_ids = seq_length - ones
        position_ids.stop_gradient = True

        embedding_output = self.ptm.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids)
        sequence_output = self.ptm.encoder(embedding_output, src_mask)
        cls_embedding = self.ptm.pooler(sequence_output)

        if self.output_emb_size > 0:
            cls_embedding = self.emb_reduce_linear(cls_embedding)
        cls_embedding = self.dropout(cls_embedding)
        cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)

        return cls_embedding

    def forward(self,
                query_input_ids,
                title_input_ids,
                query_token_type_ids=None,
                query_position_ids=None,
                query_attention_mask=None,
                title_token_type_ids=None,
                title_position_ids=None,
                title_attention_mask=None):
        """Return the cosine similarity of each query/title pair.

        Both sides are embedded with :meth:`get_pooled_embedding`. Because
        those embeddings are L2-normalised, the sum of their element-wise
        product over the last axis is the cosine similarity.
        """
        query_cls_embedding = self.get_pooled_embedding(
            query_input_ids, query_token_type_ids, query_position_ids,
            query_attention_mask)
        title_cls_embedding = self.get_pooled_embedding(
            title_input_ids, title_token_type_ids, title_position_ids,
            title_attention_mask)
        cosine_sim = paddle.sum(query_cls_embedding * title_cls_embedding,
                                axis=-1)
        return cosine_sim

    def load(self, init_from_params):
        """Load a parameter file into this layer.

        Args:
            init_from_params: Path of the file passed to ``paddle.load``.

        Raises:
            ValueError: If the path is empty or is not an existing file.
        """
        if not init_from_params or not os.path.isfile(init_from_params):
            raise ValueError(
                "Please set --params_path with correct pretrained model file")
        state_dict = paddle.load(init_from_params)
        self.set_state_dict(state_dict)
        print("Loaded parameters from %s" % init_from_params)


def do_predict(args):
    """Print one cosine similarity per text pair in ``args.text_pair_file``.

    Args:
        args: Namespace produced by :func:`parse_args`.
    """
    paddle.set_device("gpu")
    paddle.seed(args.seed)
    tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')

    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        pad_to_max_seq_len=args.pad_to_max_seq_len)

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # query_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # query_segment
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # title_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # title_segment
    ): [data for data in fn(samples)]

    valid_ds = load_dataset(
        read_text_pair, data_path=args.text_pair_file, lazy=False)

    valid_data_loader = create_dataloader(
        valid_ds,
        mode="predict",
        batch_size=args.batch_size,
        batchify_fn=batchify_fn,
        trans_fn=trans_func)

    pretrained_model = ErnieModel.from_pretrained("ernie-1.0")

    # The model masks every position equal to ``bos_id``; hand it the id the
    # batches are actually padded with instead of relying on the default 0.
    model = SemanticIndexingPredictor(
        pretrained_model,
        args.output_emb_size,
        bos_id=tokenizer.pad_token_id,
        max_seq_len=args.max_seq_length,
        dropout=args.dropout)
    model.eval()
    model.load(args.params_path)
    model = enable_faster_encoder(model)
    try:
        cosine_sims = []
        for batch_data in valid_data_loader:
            (query_input_ids, query_token_type_ids, title_input_ids,
             title_token_type_ids) = batch_data
            query_input_ids = paddle.to_tensor(query_input_ids)
            query_token_type_ids = paddle.to_tensor(query_token_type_ids)
            title_input_ids = paddle.to_tensor(title_input_ids)
            title_token_type_ids = paddle.to_tensor(title_token_type_ids)
            batch_cosine_sim = model(
                query_input_ids=query_input_ids,
                title_input_ids=title_input_ids,
                query_token_type_ids=query_token_type_ids,
                title_token_type_ids=title_token_type_ids).numpy()
            cosine_sims.append(batch_cosine_sim)

        # np.concatenate raises ValueError on an empty list, so an input
        # file without any pair simply prints nothing.
        if cosine_sims:
            for cosine in np.concatenate(cosine_sims, axis=0):
                print('{}'.format(cosine))
    finally:
        # Always switch the encoder back, even if prediction fails midway.
        model = disable_faster_encoder(model)


if __name__ == "__main__":
    args = parse_args()
    pprint(args)
    do_predict(args)
trans_func = partial( convert_example, tokenizer=tokenizer, - max_seq_length=args.max_seq_length) + max_seq_length=args.max_seq_length, + pad_to_max_seq_len=args.pad_to_max_seq_len) batchify_fn = lambda samples, fn=Tuple( Pad(axis=0, pad_val=tokenizer.pad_token_id), # query_input diff --git a/paddlenlp/ops/CMakeLists.txt b/paddlenlp/ops/CMakeLists.txt index cf4e4c782164..86ff7507f037 100644 --- a/paddlenlp/ops/CMakeLists.txt +++ b/paddlenlp/ops/CMakeLists.txt @@ -27,6 +27,7 @@ option(WITH_TRANSFORMER "Compile with Transformer" option(WITH_GPT "Compile with GPT" OFF) option(WITH_UNIFIED "Compile with Unified Transformer" ON) option(WITH_DECODER "Compile with Transformer Decoder" ON) +option(WITH_ENCODER "Compile with Transformer Encoder" ON) if(NOT WITH_GPU) message(FATAL_ERROR "Faster transformer custom op doesn't support CPU. Please add the flag -DWITH_GPU=ON to use GPU. ") @@ -44,12 +45,16 @@ if(WITH_UNIFIED) list(APPEND decoding_op_files fusion_unified_decoding_op.cc fusion_unified_decoding_op.cu) endif() +if(WITH_ENCODER) + list(APPEND decoding_op_files fusion_encoder_op.cc fusion_encoder_op.cu) +endif() + if(WITH_DECODER) list(APPEND decoder_op_files fusion_decoder_op.cc fusion_decoder_op.cu) endif() -if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER) - message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON or/and -DWITH_DECODER=ON must be set to use FasterTransformer. ") +if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER AND NOT WITH_ENCODER) + message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON or/and -DWITH_DECODER=ON or/and -DWITH_ENCODER=ON must be set to use FasterTransformer. 
") endif() set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR}) @@ -161,6 +166,13 @@ file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME} file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/cuda/topk_kernels.cuh topk_kernels_cuh_src) file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/topk_kernels.cuh topk_kernels_cuh_dst) +file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/open_attention.h open_attention_h_dst) + +file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/bert_encoder_transformer.h bert_encoder_transformer_h_src) +file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/bert_encoder_transformer.h bert_encoder_transformer_h_dst) + +set(OPT_OPEN_ATTN_COMMAND sed -i -e "370,392d" -e "410,454d" -e "229d" ${open_attention_h_dst}) +#set(OPT_BERT_ENCODER_COMMAND sed -i -e "552,592d" -e "118a bool is_gelu_=true;" ${bert_encoder_transformer_h_dst}) # TODO(guosheng): `find` seems meeting errors missing argument to `-exec', fix it set(MUTE_COMMAND grep -rl "printf(\"\\[WARNING\\]" ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/ | xargs -i{} sed -i "s/printf(\"\\WWARNING\\W decoding[^)]\\{1,\\})/ /" {}) @@ -174,6 +186,7 @@ set(FT_PATCH_COMMAND && cp ${beamsearch_h_src} ${trans_dst} && cp ${sampling_h_src} ${trans_dst} && cp ${arguments_h_src} ${trans_dst} + && cp ${bert_encoder_transformer_h_src} ${bert_encoder_transformer_h_dst} && cat ${cuda_kernels_h_src} >> ${cuda_kernels_h_dst} && cat ${cuda_kernels_cu_src} >> ${cuda_kernels_cu_dst} && cat ${decoding_kernels_cu_src} >> ${decoding_kernels_cu_dst} @@ -182,6 +195,7 @@ set(FT_PATCH_COMMAND && cat ${trans_decoder_h_src} >> ${open_decoder_h_dst} && cat ${trans_cuda_kernels_h_src} >> ${cuda_kernels_h_dst} && cat ${trans_decoding_kernels_cu_src} >> ${decoding_kernels_cu_dst} + && ${OPT_OPEN_ATTN_COMMAND} && ${MUTE_COMMAND} 
) @@ -282,3 +296,5 @@ if(ON_INFER AND WITH_GPT) endif() add_subdirectory(faster_transformer) + + diff --git a/paddlenlp/ops/__init__.py b/paddlenlp/ops/__init__.py index 12312208771f..bfb00a59c06b 100644 --- a/paddlenlp/ops/__init__.py +++ b/paddlenlp/ops/__init__.py @@ -15,6 +15,13 @@ from .faster_transformer.transformer.decoding import * from .faster_transformer.transformer.faster_transformer import * from .faster_transformer.transformer.decoder import * +from .faster_transformer.transformer.encoder import * from .einsum import * from .distributed import * from . import optimizer + +paddle.nn.TransformerEncoderLayer._ft_forward = encoder_layer_forward +paddle.nn.TransformerEncoder._ft_forward = encoder_forward + +paddle.nn.TransformerEncoderLayer._ori_forward = paddle.nn.TransformerEncoderLayer.forward +paddle.nn.TransformerEncoder._ori_forward = paddle.nn.TransformerEncoder.forward diff --git a/paddlenlp/ops/faster_transformer/src/CMakeLists.txt b/paddlenlp/ops/faster_transformer/src/CMakeLists.txt index e2ec6bdc0dad..0d2a9678416f 100644 --- a/paddlenlp/ops/faster_transformer/src/CMakeLists.txt +++ b/paddlenlp/ops/faster_transformer/src/CMakeLists.txt @@ -146,6 +146,7 @@ if(ON_INFER) endif(NOT WIN32) cuda_add_library(pd_infer_custom_op ${decoding_op_files} ${decoder_op_files} SHARED) + add_dependencies(pd_infer_custom_op extern_${THIRD_PARTY_NAME}) string(REPLACE "/" ";" DEMO_PATH ${DEMO}) @@ -268,9 +269,9 @@ else(ON_INFER) add_library(decoding_op SHARED ${decoding_op_files}) add_dependencies(decoding_op extern_${THIRD_PARTY_NAME} boost) - target_link_libraries(decoding_op PRIVATE -lcublas -lcudart ${lib_link} ${ft_lib_link}) + target_link_libraries(decoding_op PRIVATE -lcublas -lcudart ${lib_link} ${ft_lib_link} -lencoder) add_library(decoder_op SHARED ${decoder_op_files}) add_dependencies(decoder_op extern_${THIRD_PARTY_NAME} boost) target_link_libraries(decoder_op PRIVATE -lcublas -lcudart -ldecoder ${lib_link}) -endif() \ No newline at end of file +endif() 
diff --git a/paddlenlp/ops/faster_transformer/src/fusion_encoder_op.cc b/paddlenlp/ops/faster_transformer/src/fusion_encoder_op.cc new file mode 100644 index 000000000000..7559129ec031 --- /dev/null +++ b/paddlenlp/ops/faster_transformer/src/fusion_encoder_op.cc @@ -0,0 +1,198 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include +#include + +#include "fusion_encoder_op.h" + +std::vector EncoderForward( + const paddle::Tensor& input, + const paddle::Tensor& attn_query_weight, + const paddle::Tensor& attn_query_bias, + const paddle::Tensor& attn_key_weight, + const paddle::Tensor& attn_key_bias, + const paddle::Tensor& attn_value_weight, + const paddle::Tensor& attn_value_bias, + const paddle::Tensor& attn_output_weight, + const paddle::Tensor& attn_output_bias, + const paddle::Tensor& attn_mask, + const paddle::Tensor& attn_output_layernorm_weight, + const paddle::Tensor& attn_output_layernorm_bias, + const paddle::Tensor& output_layernorm_weight, + const paddle::Tensor& output_layernorm_bias, + const paddle::Tensor& ffn_intermediate_weight, + const paddle::Tensor& ffn_intermediate_bias, + const paddle::Tensor& ffn_output_weight, + const paddle::Tensor& ffn_output_bias, + // const paddle::Tensor& sequence_id_offset, + // const paddle::Tensor& trt_seqlen_offset, + // const paddle::Tensor& amax_list, + const int64_t& head_num, + const int64_t& size_per_head, + const bool& is_gelu, + const bool& remove_padding, + 
const int64_t& int8_mode, + const int64_t& num_layer, + const int64_t& layer_idx, + const bool& allow_gemm_test, + const bool& use_trt_kernel) { + if (input.place() == paddle::PlaceType::kGPU) { + auto shape = input.shape(); + auto encoder_out = paddle::Tensor(paddle::PlaceType::kGPU, shape); + return EncoderCUDAForward(input, + attn_query_weight, + attn_query_bias, + attn_key_weight, + attn_key_bias, + attn_value_weight, + attn_value_bias, + attn_output_weight, + attn_output_bias, + attn_mask, + attn_output_layernorm_weight, + attn_output_layernorm_bias, + output_layernorm_weight, + output_layernorm_bias, + ffn_intermediate_weight, + ffn_intermediate_bias, + ffn_output_weight, + ffn_output_bias, + // sequence_id_offset, + // trt_seqlen_offset, + // amax_list, + encoder_out, + head_num, + size_per_head, + is_gelu, + remove_padding, + int8_mode, // no support now + num_layer, + layer_idx, + allow_gemm_test, + use_trt_kernel); + } else { + PD_THROW("Not implemented place. Only GPU is supported. 
"); + } +} + +std::vector> EncoderInferShape( + const std::vector& input_shape, + const std::vector& attn_query_weight_shape, + const std::vector& attn_query_bias_shape, + const std::vector& attn_key_weight_shape, + const std::vector& attn_key_bias_shape, + const std::vector& attn_value_weight_shape, + const std::vector& attn_value_bias_shape, + const std::vector& attn_output_weight_shape, + const std::vector& attn_output_bias_shape, + const std::vector& attn_mask_shape, + const std::vector& attn_output_layernorm_weight_shape, + const std::vector& attn_output_layernorm_bias_shape, + const std::vector& output_layernorm_weight_shape, + const std::vector& output_layernorm_bias_shape, + const std::vector& ffn_intermediate_weight_shape, + const std::vector& ffn_intermediate_bias_shape, + const std::vector& ffn_output_weight_shape, + const std::vector& ffn_output_bias_shape, + // const std::vector& sequence_id_offset, + // const std::vector& trt_seqlen_offset, + // const std::vector& amax_list_shape, + const int64_t& head_num, + const int64_t& size_per_head, + const bool& is_gelu, + const bool& remove_padding, + const int64_t& int8_mode, // no support now + const int64_t& num_layer, + const int64_t& layer_idx, + const bool& allow_gemm_test, + const bool& use_trt_kernel) { + return {input_shape}; +} + + +std::vector EncoderInferDtype( + const paddle::DataType& input, + const paddle::DataType& attn_query_weight, + const paddle::DataType& attn_query_bias, + const paddle::DataType& attn_key_weight, + const paddle::DataType& attn_key_bias, + const paddle::DataType& attn_value_weight, + const paddle::DataType& attn_value_bias, + const paddle::DataType& attn_output_weight, + const paddle::DataType& attn_output_bias, + const paddle::DataType& attn_mask, + const paddle::DataType& attn_output_layernorm_weight, + const paddle::DataType& attn_output_layernorm_bias, + const paddle::DataType& output_layernorm_weight, + const paddle::DataType& output_layernorm_bias, + const 
paddle::DataType& ffn_intermediate_weight, + const paddle::DataType& ffn_intermediate_bias, + const paddle::DataType& ffn_output_weight, + const paddle::DataType& ffn_output_bias) { + // const paddle::DataType& sequence_id_offset, + // const paddle::DataType& trt_seqlen_offset, + // const paddle::DataType& amax_list) { + switch (input) { + case paddle::DataType::FLOAT16: { + return {input}; + } + case paddle::DataType::FLOAT32: { + return {input}; + } + default: { + PD_THROW( + "NOT supported data type. " + "Only float16 and float32 are supported. "); + break; + } + } +} + +PD_BUILD_OP(fusion_encoder) + .Inputs({ + "Input", + "SelfQueryWeight", + "SelfQueryBias", + "SelfKeyWeight", + "SelfKeyBias", + "SelfValueWeight", + "SelfValueBias", + "SelfAttnOutputWeight", + "SelfAttnOutputBias", + "SelfAttnMask", + "SelfAttnOutputLayernormWeight", + "SelfAttnOutputLayernormBias", + "OutputLayernormWeight", + "OutputLayernormBias", + "FFNInterWeight", + "FFNInterBias", + "FFNOutputWeight", + "FFNOutputBias", + // "SequenceIdOffset", + // "TRTSeqLenOffset", + // "AmaxList", + }) + .Outputs({"EncoderOut"}) + .Attrs({"head_num: int64_t", + "size_per_head: int64_t", + "is_gelu: bool", + "remove_padding: bool", + "int8_mode: int64_t", + "num_layer: int64_t", + "layer_idx: int64_t", + "allow_gemm_test: bool", + "use_trt_kernel: bool"}) + .SetKernelFn(PD_KERNEL(EncoderForward)) + .SetInferShapeFn(PD_INFER_SHAPE(EncoderInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(EncoderInferDtype)); diff --git a/paddlenlp/ops/faster_transformer/src/fusion_encoder_op.cu b/paddlenlp/ops/faster_transformer/src/fusion_encoder_op.cu new file mode 100644 index 000000000000..22aa29733b30 --- /dev/null +++ b/paddlenlp/ops/faster_transformer/src/fusion_encoder_op.cu @@ -0,0 +1,296 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "fastertransformer/cuda/cub/cub.cuh" +#include "fastertransformer/cuda/cuda_kernels.h" +#include "fastertransformer/faster_transformer.h" +#include "fusion_encoder_op.h" +#include "pd_traits.h" + + +template +std::vector encoder_kernel( + const paddle::Tensor& input, + const paddle::Tensor& attn_query_weight, + const paddle::Tensor& attn_query_bias, + const paddle::Tensor& attn_key_weight, + const paddle::Tensor& attn_key_bias, + const paddle::Tensor& attn_value_weight, + const paddle::Tensor& attn_value_bias, + const paddle::Tensor& attn_output_weight, + const paddle::Tensor& attn_output_bias, + const paddle::Tensor& attn_mask, + const paddle::Tensor& attn_output_layernorm_weight, + const paddle::Tensor& attn_output_layernorm_bias, + const paddle::Tensor& output_layernorm_weight, + const paddle::Tensor& output_layernorm_bias, + const paddle::Tensor& ffn_intermediate_weight, + const paddle::Tensor& ffn_intermediate_bias, + const paddle::Tensor& ffn_output_weight, + const paddle::Tensor& ffn_output_bias, + // const paddle::Tensor& sequence_id_offset, + // const paddle::Tensor& trt_seqlen_offset, + // const paddle::Tensor& amax_list, + paddle::Tensor& encoder_out, + int64_t head_num_, + int64_t size_per_head_, + bool is_gelu, + bool remove_padding, + int64_t int8_mode, // no support now + int64_t num_layer_, + int64_t layer_idx_, + bool allow_gemm_test, + bool use_trt_kernel_, + cublasHandle_t cublas_handle_, + cudaStream_t stream) { + int batch_size_ = input.shape()[0]; 
+ int max_seq_len_ = input.shape()[1]; + typedef PDTraits traits_; + typedef BertEncoderTransformerTraits + EncoderTraits_; + fastertransformer::Allocator* allocator_ = + new fastertransformer::Allocator(stream); + + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t_; + + EncoderInitParam encoder_param; + + encoder_param.stream = stream; + encoder_param.cublas_handle = cublas_handle_; + encoder_param.from_tensor = + reinterpret_cast(input.data()); + + encoder_param.to_tensor = + reinterpret_cast(input.data()); + encoder_param.transformer_out = reinterpret_cast( + encoder_out.mutable_data(input.place())); + // self attn + encoder_param.self_attention.query_weight.kernel = + reinterpret_cast(attn_query_weight.data()); + encoder_param.self_attention.query_weight.bias = + reinterpret_cast(attn_query_bias.data()); + encoder_param.self_attention.key_weight.kernel = + reinterpret_cast(attn_key_weight.data()); + encoder_param.self_attention.key_weight.bias = + reinterpret_cast(attn_key_bias.data()); + encoder_param.self_attention.value_weight.kernel = + reinterpret_cast(attn_value_weight.data()); + encoder_param.self_attention.value_weight.bias = + reinterpret_cast(attn_value_bias.data()); + encoder_param.attr_mask = + reinterpret_cast(attn_mask.data()); + encoder_param.self_attention.attention_output_weight.kernel = + reinterpret_cast(attn_output_weight.data()); + encoder_param.self_attention.attention_output_weight.bias = + reinterpret_cast(attn_output_bias.data()); + encoder_param.self_layernorm.gamma = reinterpret_cast( + attn_output_layernorm_weight.data()); + encoder_param.self_layernorm.beta = reinterpret_cast( + attn_output_layernorm_bias.data()); + encoder_param.ffn.intermediate_weight.kernel = + reinterpret_cast( + ffn_intermediate_weight.data()); + encoder_param.ffn.intermediate_weight.bias = + reinterpret_cast(ffn_intermediate_bias.data()); + + encoder_param.ffn.output_weight.kernel = + 
reinterpret_cast(ffn_output_weight.data()); + encoder_param.ffn.output_weight.bias = + reinterpret_cast(ffn_output_bias.data()); + + encoder_param.ffn_layernorm.gamma = reinterpret_cast( + output_layernorm_weight.data()); + encoder_param.ffn_layernorm.beta = + reinterpret_cast(output_layernorm_bias.data()); + int valid_word_num; + // if (remove_padding) { + // valid_word_num = sequence_id_offset.shape()[0]; + // encoder_param.sequence_id_offset = sequence_id_offset.data(); + // } else { + encoder_param.sequence_id_offset = nullptr; + valid_word_num = batch_size_ * max_seq_len_; + // } + encoder_param.valid_word_num = valid_word_num; + + encoder_param.trt_seqlen_offset = nullptr; // trt_seqlen_offset.data(); + encoder_param.trt_seqlen_size = batch_size_ + 1; + // static_cast(trt_seqlen_offset.shape()[0]); + // int8_mode = 0; + // if (int8_mode != 0) { + // encoder_param.amaxList = + // reinterpret_cast(amax_list.data()); + // encoder_param.layer_num = num_layer_; + // encoder_param.layer_idx = layer_idx_; + // } else { + encoder_param.amaxList = nullptr; + // } + + BertEncoderTransformer* encoder = + new BertEncoderTransformer(int8_mode, allow_gemm_test); + + encoder->allocateBuffer(allocator_, + batch_size_, + max_seq_len_, + max_seq_len_, + head_num_, + size_per_head_, + is_gelu, + use_trt_kernel_); + encoder->initialize(encoder_param); + encoder->forward(); + encoder->freeBuffer(); + delete allocator_; + delete encoder; + + return {encoder_out}; +} + + +std::vector EncoderCUDAForward( + const paddle::Tensor& input, + const paddle::Tensor& attn_query_weight, + const paddle::Tensor& attn_query_bias, + const paddle::Tensor& attn_key_weight, + const paddle::Tensor& attn_key_bias, + const paddle::Tensor& attn_value_weight, + const paddle::Tensor& attn_value_bias, + const paddle::Tensor& attn_output_weight, + const paddle::Tensor& attn_output_bias, + const paddle::Tensor& attn_mask, + const paddle::Tensor& attn_output_layernorm_weight, + const paddle::Tensor& 
attn_output_layernorm_bias, + const paddle::Tensor& output_layernorm_weight, + const paddle::Tensor& output_layernorm_bias, + const paddle::Tensor& ffn_intermediate_weight, + const paddle::Tensor& ffn_intermediate_bias, + const paddle::Tensor& ffn_output_weight, + const paddle::Tensor& ffn_output_bias, + // const paddle::Tensor& sequence_id_offset, + // const paddle::Tensor& trt_seqlen_offset, + // const paddle::Tensor& amax_list, + paddle::Tensor& encoder_out, + int64_t head_num, + int64_t size_per_head, + bool is_gelu, + bool remove_padding, + int64_t int8_mode, + int64_t num_layer, + int64_t layer_idx, + bool allow_gemm_test, + bool use_trt_kernel) { + auto stream = input.stream(); + cublasHandle_t cublas_handle_; + cublasCreate(&cublas_handle_); + cublasSetStream(cublas_handle_, stream); + + std::vector ret; + + switch (input.type()) { + case paddle::DataType::FLOAT16: { + ret = encoder_kernel( + input, + attn_query_weight, + attn_query_bias, + attn_key_weight, + attn_key_bias, + attn_value_weight, + attn_value_bias, + attn_output_weight, + attn_output_bias, + attn_mask, + attn_output_layernorm_weight, + attn_output_layernorm_bias, + output_layernorm_weight, + output_layernorm_bias, + ffn_intermediate_weight, + ffn_intermediate_bias, + ffn_output_weight, + ffn_output_bias, + // sequence_id_offset, + // trt_seqlen_offset, + // amax_list, + encoder_out, + head_num, + size_per_head, + is_gelu, + remove_padding, + int8_mode, + num_layer, + layer_idx, + allow_gemm_test, + use_trt_kernel, + cublas_handle_, + stream); + + break; + } + case paddle::DataType::FLOAT32: { + ret = encoder_kernel( + input, + attn_query_weight, + attn_query_bias, + attn_key_weight, + attn_key_bias, + attn_value_weight, + attn_value_bias, + attn_output_weight, + attn_output_bias, + attn_mask, + attn_output_layernorm_weight, + attn_output_layernorm_bias, + output_layernorm_weight, + output_layernorm_bias, + ffn_intermediate_weight, + ffn_intermediate_bias, + ffn_output_weight, + 
ffn_output_bias, + // sequence_id_offset, + // trt_seqlen_offset, + // amax_list, + encoder_out, + head_num, + size_per_head, + is_gelu, + remove_padding, + int8_mode, + num_layer, + layer_idx, + allow_gemm_test, + use_trt_kernel, + cublas_handle_, + stream); + break; + } + default: { + PD_THROW( + "NOT supported data type. " + "Only float16 and float32 are supported. "); + break; + } + } + + cublasDestroy(cublas_handle_); + return ret; +} diff --git a/paddlenlp/ops/faster_transformer/src/fusion_encoder_op.h b/paddlenlp/ops/faster_transformer/src/fusion_encoder_op.h new file mode 100644 index 000000000000..ffd6191aa779 --- /dev/null +++ b/paddlenlp/ops/faster_transformer/src/fusion_encoder_op.h @@ -0,0 +1,60 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ +#pragma once + +#include +#include + +#include "fastertransformer/bert_encoder_transformer.h" +#include "fastertransformer/common.h" + +#ifdef PADDLE_ON_INFERENCE +#include "paddle/include/experimental/ext_all.h" +#else +#include "paddle/extension.h" +#endif + + +std::vector EncoderCUDAForward( + const paddle::Tensor& input, + const paddle::Tensor& attn_query_weight, + const paddle::Tensor& attn_query_bias, + const paddle::Tensor& attn_key_weight, + const paddle::Tensor& attn_key_bias, + const paddle::Tensor& attn_value_weight, + const paddle::Tensor& attn_value_bias, + const paddle::Tensor& attn_output_weight, + const paddle::Tensor& attn_output_bias, + const paddle::Tensor& attn_mask, + const paddle::Tensor& attn_output_layernorm_weight, + const paddle::Tensor& attn_output_layernorm_bias, + const paddle::Tensor& output_layernorm_weight, + const paddle::Tensor& output_layernorm_bias, + const paddle::Tensor& ffn_intermediate_weight, + const paddle::Tensor& ffn_intermediate_bias, + const paddle::Tensor& ffn_output_weight, + const paddle::Tensor& ffn_output_bias, + // const paddle::Tensor& sequence_id_offset, + // const paddle::Tensor& trt_seqlen_offset, + // const paddle::Tensor& amax_list, + paddle::Tensor& encoder_out, + int64_t head_num_, + int64_t size_per_head_, + bool is_gelu, + bool remove_padding, + int64_t int8_mode, // no support now + int64_t num_layer_, + int64_t layer_idx_, + bool allow_gemm_test, + bool use_trt_kernel_); diff --git a/paddlenlp/ops/faster_transformer/transformer/encoder.py b/paddlenlp/ops/faster_transformer/transformer/encoder.py new file mode 100644 index 000000000000..e99d6ca0e784 --- /dev/null +++ b/paddlenlp/ops/faster_transformer/transformer/encoder.py @@ -0,0 +1,307 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import paddle
from paddle.fluid.layer_helper import LayerHelper
import paddle.nn as nn
from paddle.nn import TransformerEncoder, TransformerEncoderLayer

from paddlenlp.utils.log import logger
from paddlenlp.ops.ext_utils import load


def infer_transformer_encoder(
        input,
        q_weight,
        q_bias,
        k_weight,
        k_bias,
        v_weight,
        v_bias,
        attn_out_weight,
        attn_out_bias,
        attn_mask,
        attn_ln_weight,
        attn_ln_bias,
        out_ln_weight,
        out_ln_bias,
        ffn_inter_weight,
        ffn_inter_bias,
        ffn_out_weight,
        ffn_out_bias,
        # NOTE: sequence_id_offset, trt_seqlen_offset and amax_list are not
        # wired through to the op yet; they are intentionally absent here.
        n_head,
        size_per_head,
        n_layer=12,
        is_gelu=True,
        remove_padding=False,
        int8_mode=0,
        layer_idx=0,
        allow_gemm_test=False,
        use_trt_kernel=False):
    """
    Fusion Encoder API integrating Encoder inference in FasterTransformer.

    Appends one `fusion_encoder` operator that consumes the weights and biases
    of a single Transformer encoder layer and returns the operator's output
    variable.

    Args:
        input (Tensor): Input of the encoder layer.
        q_weight, q_bias (Tensor): Query projection parameters.
        k_weight, k_bias (Tensor): Key projection parameters.
        v_weight, v_bias (Tensor): Value projection parameters.
        attn_out_weight, attn_out_bias (Tensor): Attention output projection
            parameters.
        attn_mask (Tensor): Attention mask passed to the operator as
            `SelfAttnMask`.
        attn_ln_weight, attn_ln_bias (Tensor): Parameters of the layer norm
            applied after self attention.
        out_ln_weight, out_ln_bias (Tensor): Parameters of the output layer
            norm.
        ffn_inter_weight, ffn_inter_bias (Tensor): Parameters of the first
            feed-forward linear layer.
        ffn_out_weight, ffn_out_bias (Tensor): Parameters of the second
            feed-forward linear layer.
        n_head (int): Number of attention heads (`head_num` attribute).
        size_per_head (int): Hidden size of each head.
        n_layer (int, optional): `num_layer` attribute. Defaults to 12.
        is_gelu (bool, optional): `is_gelu` attribute. Defaults to True.
        remove_padding (bool, optional): `remove_padding` attribute.
            Defaults to False.
        int8_mode (int, optional): `int8_mode` attribute. Defaults to 0.
        layer_idx (int, optional): `layer_idx` attribute. Defaults to 0.
        allow_gemm_test (bool, optional): `allow_gemm_test` attribute.
            Defaults to False.
        use_trt_kernel (bool, optional): `use_trt_kernel` attribute.
            Defaults to False.

    Returns:
        Tensor: The `EncoderOut` variable of the operator, created with the
        same dtype as `input`.
    """
    # `locals()` must be captured before any other local name is bound so the
    # helper receives exactly the call arguments.
    helper = LayerHelper('fusion_encoder', **locals())
    inputs = {
        'Input': input,
        'SelfQueryWeight': q_weight,
        'SelfQueryBias': q_bias,
        'SelfKeyWeight': k_weight,
        'SelfKeyBias': k_bias,
        'SelfValueWeight': v_weight,
        'SelfValueBias': v_bias,
        'SelfAttnOutputWeight': attn_out_weight,
        'SelfAttnOutputBias': attn_out_bias,
        "SelfAttnMask": attn_mask,
        'SelfAttnOutputLayernormWeight': attn_ln_weight,
        'SelfAttnOutputLayernormBias': attn_ln_bias,
        'OutputLayernormWeight': out_ln_weight,
        'OutputLayernormBias': out_ln_bias,
        'FFNInterWeight': ffn_inter_weight,
        'FFNInterBias': ffn_inter_bias,
        'FFNOutputWeight': ffn_out_weight,
        'FFNOutputBias': ffn_out_bias,
    }
    attrs = {
        'head_num': n_head,
        'size_per_head': size_per_head,
        'is_gelu': is_gelu,
        "remove_padding": remove_padding,
        'int8_mode': int8_mode,
        'num_layer': n_layer,
        'layer_idx': layer_idx,
        'allow_gemm_test': allow_gemm_test,
        'use_trt_kernel': use_trt_kernel,
    }
    encoder_out = helper.create_variable(dtype=input.dtype)
    outputs = {"EncoderOut": encoder_out}

    helper.append_op(
        type='fusion_encoder', inputs=inputs, outputs=outputs, attrs=attrs)
    return encoder_out


def encoder_layer_forward(self,
                          src,
                          src_mask,
                          cache=None,
                          sequence_id_offset=None,
                          trt_seq_len=None):
    """
    Redefines `forward` function of `paddle.nn.TransformerEncoderLayer` for
    integrating FasterTransformer for inference.

    The original `forward` function would not be replaced unless
    `enable_faster_encoder` is called by objects of its base class. After
    replacing, objects of `paddle.nn.TransformerEncoderLayer` also have the
    same member variables as before.

    After inference, `disable_faster_encoder` could be called to restore the
    `forward` function of `paddle.nn.TransformerEncoder` and
    `paddle.nn.TransformerEncoderLayer`.

    Args:
        src (Tensor):
            The input of Transformer encoder layer. It is a tensor with shape
            `[batch_size, sequence_length, d_model]`. The fused operator only
            supports float16 and float32.
        src_mask (Tensor):
            The attention mask. It is forwarded unchanged to the fused
            operator as its attention mask.
        cache (optional):
            Not supported by the fused operator; must be None.
            Defaults to None.
        sequence_id_offset (optional):
            Currently unused. Defaults to None.
        trt_seq_len (optional):
            Currently unused. Defaults to None.

    Returns:
        Tensor: The output of the fused encoder layer, created with the same
        data type as `src`.

    Raises:
        NotImplementedError: If `cache` is not None.
    """
    if cache is not None:
        raise NotImplementedError("cache in encoder is not supported now")

    config = self._config
    attention = self.self_attn
    src = infer_transformer_encoder(
        input=src,
        q_weight=attention.q_proj.weight,
        q_bias=attention.q_proj.bias,
        k_weight=attention.k_proj.weight,
        k_bias=attention.k_proj.bias,
        v_weight=attention.v_proj.weight,
        v_bias=attention.v_proj.bias,
        attn_out_weight=attention.out_proj.weight,
        attn_out_bias=attention.out_proj.bias,
        attn_mask=src_mask,
        attn_ln_weight=self.norm1.weight,
        attn_ln_bias=self.norm1.bias,
        out_ln_weight=self.norm2.weight,
        out_ln_bias=self.norm2.bias,
        ffn_inter_weight=self.linear1.weight,
        ffn_inter_bias=self.linear1.bias,
        ffn_out_weight=self.linear2.weight,
        ffn_out_bias=self.linear2.bias,
        n_head=config['nhead'],
        size_per_head=config['d_model'] // config['nhead'],
        is_gelu=config['activation'] == 'gelu')
    return src


def encoder_forward(self, src, src_mask=None, cache=None):
    """
    Redefines `forward` function of `paddle.nn.TransformerEncoder` for
    integrating FasterTransformer for inference.

    The original `forward` function would not be replaced unless
    `enable_faster_encoder` is called by objects of its base class. After
    replacing, objects of `paddle.nn.TransformerEncoder` also have the same
    member variables as before.

    After inference, `disable_faster_encoder` could be called to restore the
    `forward` function of `paddle.nn.TransformerEncoder` and
    `paddle.nn.TransformerEncoderLayer`.

    Args:
        src (Tensor):
            The input of Transformer encoder. It is a tensor with shape
            `[batch_size, sequence_length, d_model]`. The fused operator only
            supports float16 and float32.
        src_mask (Tensor):
            The attention mask. It is repeated `sequence_length` times along
            axis 2 before being handed to every layer, so it is presumably
            expected as `[batch_size, 1, 1, sequence_length]` (TODO confirm
            against callers). Although the parameter defaults to None to keep
            the signature of `paddle.nn.TransformerEncoder.forward`, a mask is
            required here.
        cache (optional):
            Not supported when FasterTransformer is enabled; must be None.
            Defaults to None.

    Returns:
        Tensor: The output of the last encoder layer, passed through
        `self.norm` when it is set.

    Raises:
        NotImplementedError: If `cache` is not None.
        ValueError: If `src_mask` is None.
    """
    # The fused layers cannot produce or consume caches; fail loudly instead
    # of silently dropping the argument.
    if cache is not None:
        raise NotImplementedError("cache in encoder is not supported now")
    # Without this guard the concat below fails with an obscure error.
    if src_mask is None:
        raise ValueError(
            "`src_mask` must be provided when FasterTransformer encoder is "
            "enabled, but received None.")

    max_seq_len = src.shape[1]
    # broadcast
    src_mask = paddle.concat(x=[src_mask] * max_seq_len, axis=2)
    output = src
    for layer in self.layers:
        output = layer(output, src_mask)
    if self.norm is not None:
        output = self.norm(output)
    return output


def _bias_attr_disabled(bias_attr):
    """Returns True if `bias_attr` (a single value or a list/tuple) is, or
    contains, `False`."""
    if isinstance(bias_attr, (list, tuple)):
        return any(attr is False for attr in bias_attr)
    return bias_attr is False


def enable_faster_encoder(self):
    """
    Compiles fusion encoder operator integrated FasterTransformer using the
    method of JIT(Just-In-Time) and replaces the `forward` function of
    `paddle.nn.TransformerEncoder` and `paddle.nn.TransformerEncoderLayer`
    objects inherited from `self` to support inference using FasterTransformer.

    Nothing is changed when `self` is in training mode, or when any
    `paddle.nn.TransformerEncoderLayer` below `self` was built with
    `bias_attr=False`. If loading the operator or patching the layers fails,
    every layer already patched is rolled back, so `self` is never left
    half-patched.

    Returns:
        The same object `self`.

    Examples:

        .. code-block:: python

            from paddlenlp.ops import enable_faster_encoder, disable_faster_encoder

            model.eval()
            model = enable_faster_encoder(model)
            enc_out = model(src, src_mask)
            model = disable_faster_encoder(model)
    """

    def check_if_usable(layer):
        for sub_layer in layer.children():
            if isinstance(sub_layer,
                          TransformerEncoderLayer) and _bias_attr_disabled(
                              sub_layer._config['bias_attr']):
                logger.warning("`False` for paddle.nn.TransformerEncoder's" \
                               " parameter `bias_attr` is not supported in " \
                               "FasterTransformer by now. Original Paddle API " \
                               "would be called.")
                return False
            elif not check_if_usable(sub_layer):
                return False
        return True

    # (layer, previous forward) pairs, recorded so a failure can be undone.
    patched = []

    def init_func(layer):
        if isinstance(layer, (TransformerEncoderLayer, TransformerEncoder)):
            # Resolve the replacement first so a missing `_ft_forward` does
            # not leave a bogus rollback entry behind.
            ft_forward = layer._ft_forward
            patched.append((layer, layer.forward))
            layer.forward = ft_forward

    if self.training:
        return self
    if not check_if_usable(self):
        return self

    try:
        load("FasterTransformer", verbose=True)
        for layer in self.children():
            layer.apply(init_func)
    except Exception as err:
        # Best effort by design: fall back to the original implementation,
        # but undo partial patching and report what went wrong.
        for layer, previous_forward in patched:
            layer.forward = previous_forward
        logger.warning(
            "Exception occurs when using Faster Transformer. " \
            "The original forward will be involved. ")
        logger.warning("Details of the exception: {!r}".format(err))
    return self


def disable_faster_encoder(self):
    """
    Restores the original `forward` function of `paddle.nn.TransformerEncoder`
    and `paddle.nn.TransformerEncoderLayer` objects inherited from `self`.

    Every such sublayer below `self` gets its `forward` reset to its
    `_ori_forward` attribute.

    Returns:
        The same object `self`.

    Examples:

        .. code-block:: python

            from paddlenlp.ops import enable_faster_encoder, disable_faster_encoder

            model.eval()
            model = enable_faster_encoder(model)
            enc_out = model(src, src_mask)
            model = disable_faster_encoder(model)
    """

    def init_func(layer):
        if isinstance(layer, (TransformerEncoderLayer, TransformerEncoder)):
            layer.forward = layer._ori_forward

    for layer in self.children():
        layer.apply(init_func)
    return self
+ */ + +/** + * BERT Encoder transformer + **/ + +#pragma once + +#include +#include "fastertransformer/allocator.h" +#include "fastertransformer/common_structure.h" +#include "fastertransformer/cuda/cuda_int8_kernels.h" +#include "fastertransformer/cuda/cuda_kernels.h" +#include "fastertransformer/cuda/open_attention.h" +#include "fastertransformer/gemm_test/encoder_gemm_func.h" +#include "fastertransformer/gemm_test/encoder_igemm_func.h" + +namespace fastertransformer { + +template +class EncoderInitParam { +public: + const T *from_tensor = nullptr; + const T *to_tensor = nullptr; + + AttentionWeight self_attention; + const T *attr_mask = nullptr; + LayerNormWeight self_layernorm; + + FFNWeight ffn; + LayerNormWeight ffn_layernorm; + + T *transformer_out; + cublasHandle_t cublas_handle = nullptr; + cublasLtHandle_t cublaslt_handle = nullptr; + cudaStream_t stream = 0; + + const int *sequence_id_offset = nullptr; + int valid_word_num = -1; + int layer_idx = 0; + int layer_num = 12; + + // First 80 are for activation amaxs. 
+ // For each activation amax, there are 4 values: amax, amax/127.0f, + // amax/127.0f/127.0f, 127.0f/amax -- input_amax 0-3 , Q_aftergemm_amax 4-7, + // Qbias_amax 8-11, K_aftergemm_amax 12-15, Kbias_amax 16-19, V_aftergemm_amax + // 20-23, Vbias_amax 24-27, bmm1_amax 28-31, Softmax_amax 32-35, bmm2_amax + // 36-39, Proj_aftergemm_scale 40-43, ProjBiasNorm_amax 44-47, + // FC1_aftergemm_amax 48-51, F1Bias_amax 52-55, FC2_aftergemm_amax 56-59, + // F2BiasNorm_amax 60-63, reserve 64-79 + // following by kernel amaxs : query_weight_amax_list, key_weight_amax_list, + // value_weight_amax_list, proj_weight_amax_list, FC1_weight_amax_list, + // FC2_weight_amax_list + // following by int8 gemm deQ scale list: Q_deQ_scale, K_deQ_scale, + // V_deQ_scale, bmm1_deQ_scale, bmm2_deQ_scale, FC0_deQ_scale, FC1_deQ_scale, + // FC2_deQ_scale + const float *amaxList = nullptr; + const int *trt_seqlen_offset = nullptr; + int trt_seqlen_size = -1; +}; + +template class MultiHeadAttention_> +class BertEncoderTransformerTraits; + +template