Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions examples/semantic_indexing/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,32 @@ python -u -m paddle.distributed.launch --gpus "0" \
0.9800204038619995
```

## 使用Faster Transformer进行快速预测
与上述原生预测不同,这里的预测使用了集成 FasterTransformer 库的 Paddle 自定义算子,在一定的配置下,可以对 TransformerEncoder 的预测进行加速。

```shell
python -u -m paddle.distributed.launch --gpus "0" faster_predict.py \
--params_path "batch_neg_v1.0/model_state.pdparams" \
--output_emb_size 256 \
--batch_size 32 \
--max_seq_length 64 \
--text_pair_file ${your_input_file}

```

执行上述操作后,可以得到与原生预测非常接近的结果:在 float32 下,当 batch_size=32、max_seq_len=64 时,两种预测方式得到的最终余弦相似度的最大绝对误差约为 3.93e-6。

通过比较,可以得到在不同batch_size, max_seq_len下,使用集成了FasterTransformer的高性能算子可以对Encoder部分的推理进行加速(其余参数都与默认值相同)。在NVIDIA Tesla V100,16GB的机器上,使用单卡预测得到部分性能数据如下,从表中可以看出在更小的batch_size和max_seq_len上,使用FasterTransformer预测更有优势。

| batch size | max_seq_len | FT加速算子(单位:s) | Paddle原生(单位:s) |
| ---------- | ----------- | ------------------- | ------------------- |
| 16 | 16 | 22.645333290100098 | 51.55912470817566 |
| 16 | 32 | 27.326106071472168 | 57.17143130302429 |
| 16 | 64 | 33.31318140029907 | 52.44770574569702 |
| 32 | 16 | 12.891342163085938 | 22.621662139892578 |
| 32 | 32 | 17.206310987472534 | 22.18772006034851 |


## 模型介绍
简要介绍 In-batch negatives 策略和 HardestNeg 策略思路

Expand Down
12 changes: 8 additions & 4 deletions examples/semantic_indexing/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def create_dataloader(dataset,
trans_fn=None):
if trans_fn:
dataset = dataset.map(trans_fn)

shuffle = True if mode == 'train' else False
if mode == 'train':
batch_sampler = paddle.io.DistributedBatchSampler(
Expand All @@ -42,7 +41,10 @@ def create_dataloader(dataset,
return_list=True)


def convert_example(example, tokenizer, max_seq_length=512):
def convert_example(example,
tokenizer,
max_seq_length=512,
pad_to_max_seq_len=False):
"""
Builds model inputs from a sequence.

Expand All @@ -65,11 +67,13 @@ def convert_example(example, tokenizer, max_seq_length=512):

result = []
for key, text in example.items():
encoded_inputs = tokenizer(text=text, max_seq_len=max_seq_length)
encoded_inputs = tokenizer(
text=text,
max_seq_len=max_seq_length,
pad_to_max_seq_len=pad_to_max_seq_len)
input_ids = encoded_inputs["input_ids"]
token_type_ids = encoded_inputs["token_type_ids"]
result += [input_ids, token_type_ids]

return result


Expand Down
222 changes: 222 additions & 0 deletions examples/semantic_indexing/faster_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import partial
import argparse
from pprint import pprint
import numpy as np

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import TransformerEncoder, TransformerEncoderLayer

from paddlenlp.transformers import ErnieTokenizer, ErnieModel
from paddlenlp.data import Pad, Tuple
from paddlenlp.datasets import load_dataset
from paddlenlp.ops import enable_faster_encoder, disable_faster_encoder

from data import read_text_pair, convert_example, create_dataloader


def parse_args():
    """Parse the command-line options of the prediction script.

    Returns:
        argparse.Namespace: The parsed options. ``text_pair_file`` and
        ``params_path`` are required; every other option falls back to the
        default listed below.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--text_pair_file",
         dict(type=str, required=True, help="The full path of input file")),
        ("--output_emb_size",
         dict(type=int, default=None, help="output_embedding_size")),
        ("--params_path",
         dict(type=str,
              required=True,
              help="The path to model parameters to be loaded.")),
        ("--max_seq_length",
         dict(type=int,
              default=64,
              help="The maximum total input sequence length after tokenization. "
              "Sequences longer than this will be truncated, sequences shorter will be padded."
              )),
        ("--dropout",
         dict(type=float, default=0.0, help="Dropout probability.")),
        ("--batch_size",
         dict(type=int,
              default=32,
              help="Batch size per GPU/CPU for training.")),
        ("--seed", dict(type=int, default=42, help="Random seed.")),
        ("--pad_to_max_seq_len",
         dict(action="store_true", help="Whether to pad to max_seq_len.")),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


class SemanticIndexingPredictor(nn.Layer):
    """Score (query, title) pairs by the cosine similarity of their embeddings.

    The layer wraps a pretrained model and replaces its ``encoder`` attribute
    with a newly constructed ``paddle.nn.TransformerEncoder``. The new encoder
    is freshly initialised, so call :meth:`load` afterwards to populate it
    from a checkpoint.

    Args:
        pretrained_model: Model providing ``embeddings`` and ``pooler``
            sub-layers; its ``encoder`` attribute is overwritten.
        output_emb_size (int|None): Size of the optional projection applied to
            the pooled embedding. ``None`` or a value ``<= 0`` disables the
            projection.
        n_layer (int): Number of encoder layers to build.
        n_head (int): Number of attention heads per encoder layer.
        hidden_size (int): Model width; also the input width of the optional
            projection layer.
        dim_feedforward (int): Width of the encoder feed-forward sub-layer.
        activation (str): Accepted for interface compatibility; currently
            unused (the encoder layer is built with its own default).
        bos_id (int): Token id that is masked out of attention (every position
            whose id equals ``bos_id`` gets mask value 0).
        dropout (float|None): Dropout probability; ``None`` is treated as 0.0.
        max_seq_len (int): Accepted for interface compatibility; currently
            unused.
        is_gelu (bool): Accepted for interface compatibility; currently unused.
    """

    def __init__(self,
                 pretrained_model,
                 output_emb_size,
                 n_layer=12,
                 n_head=12,
                 hidden_size=768,
                 dim_feedforward=3072,
                 activation="relu",
                 bos_id=0,
                 dropout=0,
                 max_seq_len=128,
                 is_gelu=False):
        super(SemanticIndexingPredictor, self).__init__()
        # Normalise once so that both the dropout layer and the encoder layer
        # receive a real number even when the caller passes None.
        dropout = 0.0 if dropout is None else dropout
        self.bos_id = bos_id
        self.ptm = pretrained_model
        self.dropout = nn.Dropout(dropout)
        # The command line leaves output_emb_size as None by default; comparing
        # None with 0 raises TypeError, so map it to "projection disabled".
        self.output_emb_size = 0 if output_emb_size is None else output_emb_size
        if self.output_emb_size > 0:
            weight_attr = paddle.ParamAttr(
                initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
            # Project from the model width rather than a hard-coded 768.
            self.emb_reduce_linear = paddle.nn.Linear(
                hidden_size, self.output_emb_size, weight_attr=weight_attr)
        encoder_layer = TransformerEncoderLayer(
            hidden_size, n_head, dim_feedforward, dropout=dropout)
        self.ptm.encoder = TransformerEncoder(encoder_layer, n_layer)

    def get_pooled_embedding(self,
                             input_ids,
                             token_type_ids=None,
                             position_ids=None,
                             attention_mask=None):
        """Return the L2-normalised pooled embedding of a batch of token ids.

        Args:
            input_ids: Token-id tensor; positions equal to ``self.bos_id`` are
                masked out.
            token_type_ids: Optional segment ids forwarded to the embeddings.
            position_ids: Ignored; positions are always recomputed as
                0..len-1 along axis 1.
            attention_mask: Ignored; the mask is always derived from
                ``input_ids``.

        Returns:
            The pooled embedding, optionally projected to ``output_emb_size``,
            after dropout and L2 normalisation over the last axis.
        """
        # Mask is 1 where the token differs from bos_id and 0 elsewhere, cast
        # to the encoder's parameter dtype and expanded with two extra axes.
        mask_dtype = self.ptm.encoder.layers[0].norm1.bias.dtype
        src_mask = (input_ids != self.bos_id).astype(mask_dtype)
        src_mask = paddle.unsqueeze(src_mask, axis=[1, 2])
        src_mask.stop_gradient = True

        # cumsum(ones) - ones yields 0, 1, 2, ... along axis 1.
        ones = paddle.ones_like(input_ids, dtype="int64")
        seq_length = paddle.cumsum(ones, axis=1)
        position_ids = seq_length - ones
        position_ids.stop_gradient = True

        embedding_output = self.ptm.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids)
        sequence_output = self.ptm.encoder(embedding_output, src_mask)
        cls_embedding = self.ptm.pooler(sequence_output)

        if self.output_emb_size > 0:
            cls_embedding = self.emb_reduce_linear(cls_embedding)
        cls_embedding = self.dropout(cls_embedding)
        cls_embedding = F.normalize(cls_embedding, p=2, axis=-1)

        return cls_embedding

    def forward(self,
                query_input_ids,
                title_input_ids,
                query_token_type_ids=None,
                query_position_ids=None,
                query_attention_mask=None,
                title_token_type_ids=None,
                title_position_ids=None,
                title_attention_mask=None):
        """Return the cosine similarity of each (query, title) pair.

        Both sides are embedded with :meth:`get_pooled_embedding`; since the
        embeddings are L2-normalised, the sum of their element-wise product
        over the last axis is the cosine similarity.
        """
        query_cls_embedding = self.get_pooled_embedding(
            query_input_ids, query_token_type_ids, query_position_ids,
            query_attention_mask)
        title_cls_embedding = self.get_pooled_embedding(
            title_input_ids, title_token_type_ids, title_position_ids,
            title_attention_mask)
        cosine_sim = paddle.sum(query_cls_embedding * title_cls_embedding,
                                axis=-1)
        return cosine_sim

    def load(self, init_from_params):
        """Load parameters from the checkpoint file ``init_from_params``.

        Raises:
            ValueError: If ``init_from_params`` is empty or is not an existing
                file.
        """
        if init_from_params and os.path.isfile(init_from_params):
            state_dict = paddle.load(init_from_params)
            self.set_state_dict(state_dict)
            print("Loaded parameters from %s" % init_from_params)
        else:
            raise ValueError(
                "Please set --params_path with correct pretrained model file")


def do_predict(args):
    """Predict cosine similarities for a file of text pairs and print them.

    Builds the ERNIE tokenizer and data loader, constructs a
    ``SemanticIndexingPredictor`` around ``ernie-1.0``, loads the parameters
    from ``args.params_path``, switches the encoder to the faster
    implementation and prints one similarity per input pair.

    Args:
        args: Parsed command-line options (see ``parse_args``). The fields
            read here are ``seed``, ``max_seq_length``, ``pad_to_max_seq_len``,
            ``text_pair_file``, ``batch_size``, ``output_emb_size``,
            ``dropout`` and ``params_path``.
    """
    paddle.set_device("gpu")
    paddle.seed(args.seed)
    tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')

    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        pad_to_max_seq_len=args.pad_to_max_seq_len)

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # query_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # query_segment
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # title_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # title_segment
    ): list(fn(samples))

    valid_ds = load_dataset(
        read_text_pair, data_path=args.text_pair_file, lazy=False)

    valid_data_loader = create_dataloader(
        valid_ds,
        mode="predict",
        batch_size=args.batch_size,
        batchify_fn=batchify_fn,
        trans_fn=trans_func)

    pretrained_model = ErnieModel.from_pretrained("ernie-1.0")

    model = SemanticIndexingPredictor(
        pretrained_model,
        args.output_emb_size,
        max_seq_len=args.max_seq_length,
        dropout=args.dropout)
    model.eval()
    model.load(args.params_path)
    model = enable_faster_encoder(model)

    cosine_sims = []
    try:
        # Pure inference: no autograd graph is needed.
        with paddle.no_grad():
            for batch_data in valid_data_loader:
                (query_input_ids, query_token_type_ids, title_input_ids,
                 title_token_type_ids) = batch_data
                query_input_ids = paddle.to_tensor(query_input_ids)
                query_token_type_ids = paddle.to_tensor(query_token_type_ids)
                title_input_ids = paddle.to_tensor(title_input_ids)
                title_token_type_ids = paddle.to_tensor(title_token_type_ids)
                batch_cosine_sim = model(
                    query_input_ids=query_input_ids,
                    title_input_ids=title_input_ids,
                    query_token_type_ids=query_token_type_ids,
                    title_token_type_ids=title_token_type_ids).numpy()
                cosine_sims.append(batch_cosine_sim)
    finally:
        # Always restore the original encoder, even if a batch fails.
        model = disable_faster_encoder(model)

    # np.concatenate rejects an empty list, so skip it for an empty input file.
    if cosine_sims:
        cosine_sims = np.concatenate(cosine_sims, axis=0)
        for cosine in cosine_sims:
            print('{}'.format(cosine))


if __name__ == "__main__":
    # Script entry point: parse the options, echo them, then run prediction.
    cli_args = parse_args()
    pprint(cli_args)
    do_predict(cli_args)
4 changes: 3 additions & 1 deletion examples/semantic_indexing/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
parser.add_argument("--batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.")
parser.add_argument("--output_emb_size", default=None, type=int, help="output_embedding_size")
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
parser.add_argument("--pad_to_max_seq_len", action="store_true", help="Whether to pad to max seq length.")
args = parser.parse_args()
# yapf: enable

Expand Down Expand Up @@ -86,7 +87,8 @@ def predict(model, data_loader):
trans_func = partial(
convert_example,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length)
max_seq_length=args.max_seq_length,
pad_to_max_seq_len=args.pad_to_max_seq_len)

batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # query_input
Expand Down
20 changes: 18 additions & 2 deletions paddlenlp/ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ option(WITH_TRANSFORMER "Compile with Transformer"
option(WITH_GPT "Compile with GPT" OFF)
option(WITH_UNIFIED "Compile with Unified Transformer" ON)
option(WITH_DECODER "Compile with Transformer Decoder" ON)
option(WITH_ENCODER "Compile with Transformer Encoder" ON)

if(NOT WITH_GPU)
message(FATAL_ERROR "Faster transformer custom op doesn't support CPU. Please add the flag -DWITH_GPU=ON to use GPU. ")
Expand All @@ -44,12 +45,16 @@ if(WITH_UNIFIED)
list(APPEND decoding_op_files fusion_unified_decoding_op.cc fusion_unified_decoding_op.cu)
endif()

if(WITH_ENCODER)
list(APPEND decoding_op_files fusion_encoder_op.cc fusion_encoder_op.cu)
endif()

if(WITH_DECODER)
list(APPEND decoder_op_files fusion_decoder_op.cc fusion_decoder_op.cu)
endif()

if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER)
message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON or/and -DWITH_DECODER=ON must be set to use FasterTransformer. ")
if(NOT WITH_TRANSFORMER AND NOT WITH_GPT AND NOT WITH_DECODER AND NOT WITH_ENCODER)
message(FATAL_ERROR "-DWITH_TRANSFORMER=ON or/and -DWITH_GPT=ON or/and -DWITH_DECODER=ON or/and -DWITH_ENCODER=ON must be set to use FasterTransformer. ")
endif()

set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})
Expand Down Expand Up @@ -161,6 +166,13 @@ file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}
file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/cuda/topk_kernels.cuh topk_kernels_cuh_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/topk_kernels.cuh topk_kernels_cuh_dst)

file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/open_attention.h open_attention_h_dst)

file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/bert_encoder_transformer.h bert_encoder_transformer_h_src)
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/bert_encoder_transformer.h bert_encoder_transformer_h_dst)

set(OPT_OPEN_ATTN_COMMAND sed -i -e "370,392d" -e "410,454d" -e "229d" ${open_attention_h_dst})
#set(OPT_BERT_ENCODER_COMMAND sed -i -e "552,592d" -e "118a bool is_gelu_=true;" ${bert_encoder_transformer_h_dst})

# TODO(guosheng): `find` seems meeting errors missing argument to `-exec', fix it
set(MUTE_COMMAND grep -rl "printf(\"\\[WARNING\\]" ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/ | xargs -i{} sed -i "s/printf(\"\\WWARNING\\W decoding[^)]\\{1,\\})/ /" {})
Expand All @@ -174,6 +186,7 @@ set(FT_PATCH_COMMAND
&& cp ${beamsearch_h_src} ${trans_dst}
&& cp ${sampling_h_src} ${trans_dst}
&& cp ${arguments_h_src} ${trans_dst}
&& cp ${bert_encoder_transformer_h_src} ${bert_encoder_transformer_h_dst}
&& cat ${cuda_kernels_h_src} >> ${cuda_kernels_h_dst}
&& cat ${cuda_kernels_cu_src} >> ${cuda_kernels_cu_dst}
&& cat ${decoding_kernels_cu_src} >> ${decoding_kernels_cu_dst}
Expand All @@ -182,6 +195,7 @@ set(FT_PATCH_COMMAND
&& cat ${trans_decoder_h_src} >> ${open_decoder_h_dst}
&& cat ${trans_cuda_kernels_h_src} >> ${cuda_kernels_h_dst}
&& cat ${trans_decoding_kernels_cu_src} >> ${decoding_kernels_cu_dst}
&& ${OPT_OPEN_ATTN_COMMAND}
&& ${MUTE_COMMAND}
)

Expand Down Expand Up @@ -282,3 +296,5 @@ if(ON_INFER AND WITH_GPT)
endif()

add_subdirectory(faster_transformer)


7 changes: 7 additions & 0 deletions paddlenlp/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
from .faster_transformer.transformer.decoding import *
from .faster_transformer.transformer.faster_transformer import *
from .faster_transformer.transformer.decoder import *
from .faster_transformer.transformer.encoder import *
from .einsum import *
from .distributed import *
from . import optimizer

paddle.nn.TransformerEncoderLayer._ft_forward = encoder_layer_forward
paddle.nn.TransformerEncoder._ft_forward = encoder_forward

paddle.nn.TransformerEncoderLayer._ori_forward = paddle.nn.TransformerEncoderLayer.forward
paddle.nn.TransformerEncoder._ori_forward = paddle.nn.TransformerEncoder.forward
Loading