Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions paddlenlp/trainer/trainer_compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ def _dynabert(self, model):
# TODO: args.gradient_accumulation_steps
if args.max_steps > 0:
args.num_training_steps = args.max_steps
args.num_train_epochs = math.ceil(args.num_training_steps / len(train_dataloader))
args.num_train_epochs = math.ceil(args.num_train_epochs)
# args.num_train_epochs = args.nun_train_epoh
# import pdb; pdb.set_trace()
# if args.
# args.num_train_epochs = math.ceil(args.num_training_steps / len(train_dataloader))
else:
args.num_training_steps = len(train_dataloader) * args.num_train_epochs
args.num_train_epochs = math.ceil(args.num_train_epochs)
Expand Down Expand Up @@ -329,6 +333,8 @@ def check_dynabert_config(net_config, width_mult):
# before.
elif "out_proj" in key or "linear2" in key:
net_config[key]["expand_ratio"] = 1.0
elif "classifier" in key:
net_config[key]["expand_ratio"] = 1.0
return net_config


Expand Down Expand Up @@ -717,7 +723,7 @@ def _quant_aware_training_dynamic(self, input_dir):
"dtype": "int8",
# window size for 'range_abs_max' quantization. default is 10000
"window_size": 10000,
"quantizable_layer_type": ["Linear", "Conv2D"],
"quantizable_layer_type": ["Linear", "Conv2D", "Matmul"],
"moving_rate": args.moving_rate,
"onnx_format": args.onnx_format,
}
Expand All @@ -733,7 +739,7 @@ def _quant_aware_training_dynamic(self, input_dir):
# TODO: args.gradient_accumulation_steps
if args.max_steps > 0:
args.num_training_steps = args.max_steps
args.num_train_epochs = math.ceil(args.num_training_steps / len(train_dataloader))
args.num_train_epochs = math.ceil(args.num_train_epochs)
else:
args.num_training_steps = len(train_dataloader) * args.num_train_epochs
args.num_train_epochs = math.ceil(args.num_train_epochs)
Expand Down Expand Up @@ -810,11 +816,19 @@ def _quant_aware_training_dynamic(self, input_dir):
)
paddle.save(model_to_save.state_dict(), output_param_path)
logger.info("eval done total: %s s" % (time.time() - tic_eval))
if global_step >= args.num_training_steps:
break
if global_step >= args.num_training_steps:
break
logger.info("Best result: %.4f" % best_acc)
self.model.set_state_dict(paddle.load(output_param_path))

input_spec = generate_input_spec(self.model, self.train_dataset, self.args.input_dtype)
logger.info("Load parameters from: %s" % output_param_path)
self.model.set_state_dict(paddle.load(output_param_path))

input_spec = [
paddle.static.InputSpec(shape=[None, None], dtype="int32", name="input_ids"),
paddle.static.InputSpec(shape=[None, None], dtype="int32", name="short_session_input_ids"),
]
quanter.save_quantized_model(
self.model, os.path.join(input_dir, args.output_filename_prefix), input_spec=input_spec
)
Expand Down Expand Up @@ -1018,6 +1032,7 @@ def cut_embeddings(model, tokenizer, config, word_emb_index, max_seq_length, max
# Rewrites config
config["max_position_embeddings"] = max_seq_length
config["vocab_size"] = max_vocab_size
config["hidden_act"] = "relu6"
config.save_pretrained(output_dir)

# Rewrites vocab file
Expand Down