Enable FP16 activations in MatMulNBits#341

Closed
qti-mattsinc wants to merge 1 commit into main from dev/mattsinc/matmulnbits-fp16

@qti-mattsinc

Collaborator

Description

  • Remove FP32 input restriction in MatMulNBits op builder. Note that the scales initializer must be cast to FP32 in the op builder as QNN currently requires FP32 scales at the API level.
  • Add FP16 MatMulNBits unit tests.
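
The FP16-to-FP32 cast of the scales mentioned above can be sketched roughly as follows. This is a standalone illustration, not the op builder's actual code: `HalfBitsToFloat` and `CastScalesToFloat` are hypothetical names, and the sketch assumes the FP16 scales are available as raw IEEE 754 binary16 bit patterns in a `std::vector<uint16_t>`. The real implementation may well rely on existing ONNX Runtime helpers instead.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper: widen one IEEE 754 binary16 value (raw bits) to float.
inline float HalfBitsToFloat(uint16_t h) {
  const uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
  const uint32_t exponent = (h >> 10) & 0x1Fu;
  uint32_t mantissa = h & 0x3FFu;
  uint32_t bits = 0;

  if (exponent == 0) {
    if (mantissa == 0) {
      bits = sign;  // Signed zero.
    } else {
      // Subnormal half: shift until the implicit leading 1 appears.
      uint32_t shift = 0;
      while ((mantissa & 0x400u) == 0) {
        mantissa <<= 1;
        ++shift;
      }
      mantissa &= 0x3FFu;  // Drop the now-implicit leading 1.
      // Value is 1.m * 2^(-14 - shift); float exponent bias is 127.
      bits = sign | ((113u - shift) << 23) | (mantissa << 13);
    }
  } else if (exponent == 0x1Fu) {
    bits = sign | 0x7F800000u | (mantissa << 13);  // Infinity or NaN.
  } else {
    // Normal number: rebias the exponent from 15 to 127 (difference 112).
    bits = sign | ((exponent + 112u) << 23) | (mantissa << 13);
  }

  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

// Hypothetical helper: convert a buffer of FP16 scales to FP32 scales.
inline std::vector<float> CastScalesToFloat(const std::vector<uint16_t>& fp16_scales) {
  std::vector<float> fp32_scales;
  fp32_scales.reserve(fp16_scales.size());
  for (uint16_t h : fp16_scales) {
    fp32_scales.push_back(HalfBitsToFloat(h));
  }
  return fp32_scales;
}
```

Widening FP16 to FP32 is exact, so a cast like this does not change the scale values; it only changes their storage type to the one the backend expects.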

Motivation and Context

  • Enable w4a16 LLMs on the GPU for faster inferencing.

@minfhong-qti minfhong-qti left a comment

Collaborator

A gentle heads-up: I'm working on supporting MatMulNBits for HTP in PR #288, which will involve a slight refactor. Feel free to review that PR.

Comment on lines +257 to +267
const OrtTypeInfo* type_info = nullptr;
const auto& ort_api = qnn_model_wrapper.GetOrtApi();
ORT_CXX_RETURN_ON_API_FAIL(ort_api.GetValueInfoTypeInfo(scale_tensor_proto, &type_info));
const OrtTensorTypeAndShapeInfo* tensor_type_and_shape_info = nullptr;
ORT_CXX_RETURN_ON_API_FAIL(ort_api.CastTypeInfoToTensorInfo(type_info, &tensor_type_and_shape_info));
ONNXTensorElementDataType onnx_data_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
ORT_CXX_RETURN_ON_API_FAIL(ort_api.GetTensorElementType(tensor_type_and_shape_info, &onnx_data_type));

RETURN_IF(onnx_data_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT &&
              onnx_data_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
          "Unsupported scales datatype");

Suggested change
- const OrtTypeInfo* type_info = nullptr;
- const auto& ort_api = qnn_model_wrapper.GetOrtApi();
- ORT_CXX_RETURN_ON_API_FAIL(ort_api.GetValueInfoTypeInfo(scale_tensor_proto, &type_info));
- const OrtTensorTypeAndShapeInfo* tensor_type_and_shape_info = nullptr;
- ORT_CXX_RETURN_ON_API_FAIL(ort_api.CastTypeInfoToTensorInfo(type_info, &tensor_type_and_shape_info));
- ONNXTensorElementDataType onnx_data_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
- ORT_CXX_RETURN_ON_API_FAIL(ort_api.GetTensorElementType(tensor_type_and_shape_info, &onnx_data_type));
- RETURN_IF(onnx_data_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT &&
-               onnx_data_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
-           "Unsupported scales datatype");
+ RETURN_IF(scales_tensor.type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT &&
+               scale_tensor.type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
+           "Unsupported scales datatype");

@qti-mattsinc

Collaborator Author

Closing; this change was brought into a related PR: #288.
