-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlmc_ops.cpp
More file actions
22 lines (19 loc) · 850 Bytes
/
lmc_ops.cpp
File metadata and controls
22 lines (19 loc) · 850 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#include <pybind11/pybind11.h>
#include "mem_kernels.cuh"
#include <torch/torch.h>
#include <iostream>
// void launch_my_kernel(); // declared somewhere
// Defines the Python extension module `lmc_ops`. Each entry below is exposed
// to Python under the same name as the C++ function it wraps; no docstrings
// or argument names are attached.
// NOTE(review): the wrapped functions are not declared in this file —
// presumably they come from "mem_kernels.cuh"; confirm.
PYBIND11_MODULE(lmc_ops, m) {
  // module::def returns the module by reference, so the registrations are
  // written as one chained expression (same names, same order).
  m.def("multi_layer_kv_transfer", &multi_layer_kv_transfer)
      .def("multi_layer_kv_transfer_cudaMemcpy",
           &multi_layer_kv_transfer_cudaMemcpy)
      .def("single_layer_kv_transfer", &single_layer_kv_transfer)
      .def("load_and_reshape_flash", &load_and_reshape_flash)
      .def("reshape_and_cache_back_flash", &reshape_and_cache_back_flash);
}
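// NOTE(review): the block below is commented out and therefore not compiled.
// It sketches a TORCH_LIBRARY-style registration of a `swap_blocks` op;
// `swap_blocks` is not referenced anywhere else in this file.
// TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// // Cache ops
// // Swap in (out) the cache blocks from src to dst.
// cache_ops.def(
// "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
// cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
// }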