Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions cuda/fastermoe/smart_schedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ std::vector<torch::Tensor> _smart_sch_forward(
long global_batch_size,
long expert_size,
long n_workers,
int output_dim,
py::function forward_fn,
py::function get_param_fn,
py::function stash_fn,
Expand All @@ -95,8 +96,8 @@ std::vector<torch::Tensor> _smart_sch_forward(

// TODO: maybe empty is faster
auto global_input_buf = input_buf.new_zeros({global_batch_size, d_model});
auto global_output_buf = input_buf.new_zeros({global_batch_size, d_model});
auto output_buf = input_buf.new_zeros({input_buf.size(0), d_model});
auto global_output_buf = input_buf.new_zeros({global_batch_size, output_dim});
auto output_buf = input_buf.new_zeros({input_buf.size(0), output_dim});

std::vector<torch::Tensor> params;
auto stored_models_ = stored_models.data_ptr<bool>();
Expand All @@ -110,24 +111,22 @@ std::vector<torch::Tensor> _smart_sch_forward(
}
}

AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16,input_buf.scalar_type(),
"fmoe_cuda_smart_sch_forward", ([&] {
fmoe_cuda_fused_forward_impl(
forward_fn,
stash_fn,
pop_fn,
input_buf.device(),
params,

input_buf.data_ptr<scalar_t>(),
global_input_buf.data_ptr<scalar_t>(),
global_output_buf.data_ptr<scalar_t>(),
output_buf.data_ptr<scalar_t>(),

local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
stored_models.data_ptr<bool>(),
d_model, num_expert, rank, n_workers, expert_size,
d_model, output_dim, num_expert, rank, n_workers, expert_size,
pipeline_gran, smgr);
}));
return {output_buf, global_input_buf};
Expand All @@ -141,6 +140,7 @@ torch::Tensor _smart_sch_backward(
long buf_batch_size,
long global_batch_size,
long n_workers,
int output_dim,
py::function backward_fn,
py::function stash_fn,
py::function pop_fn,
Expand All @@ -155,7 +155,7 @@ torch::Tensor _smart_sch_backward(
auto global_grad_in = grad_out.new_zeros({global_batch_size, d_model});
auto grad_in = grad_out.new_zeros({buf_batch_size, d_model});

AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_out.scalar_type(),
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, grad_out.scalar_type(),
"fmoe_cuda_smartsch_backward", ([&] {
fmoe_cuda_fused_backward_impl(
backward_fn,
Expand All @@ -173,7 +173,7 @@ torch::Tensor _smart_sch_backward(
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
stored_models.data_ptr<bool>(),
d_model, num_expert, rank, n_workers,
d_model, output_dim, num_expert, rank, n_workers,
pipeline_gran, smgr);
}));
return grad_in;
Expand Down
16 changes: 9 additions & 7 deletions cuda/fastermoe/smart_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ void computePtrs(long num_expert, long rank, long world_size,
template<typename scalar_t>
void computeFn(py::function fn, c10::Device device,
scalar_t* inp_buf, scalar_t* out_buf,
long expert_idx, long store_idx, long offset, long micro_batch_size, long d_model,
long expert_idx, long store_idx, long offset, long micro_batch_size, long d_model, int output_dim,
CudaStreamManager* smgr) {
if(micro_batch_size == 0) {
return;
Expand All @@ -94,8 +94,8 @@ void computeFn(py::function fn, c10::Device device,
.requires_grad(true);
auto inp = torch::from_blob(inp_buf + offset * d_model,
{micro_batch_size, d_model}, options);
auto oup = torch::from_blob(out_buf + offset * d_model,
{micro_batch_size, d_model}, options);
auto oup = torch::from_blob(out_buf + offset * output_dim,
{micro_batch_size, output_dim}, options);
smgr->use_default = true;
fn(inp, oup, expert_idx, store_idx);
smgr->use_default = false;
Expand All @@ -120,6 +120,7 @@ void fmoe_cuda_fused_forward_impl(
const bool* stored_models,

long d_model,
int output_dim,
long num_expert, long rank, long world_size, long expert_size,
long pipeline_gran, CudaStreamManager* smgr) {
auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
Expand Down Expand Up @@ -196,7 +197,7 @@ void fmoe_cuda_fused_forward_impl(
(from_base + pipeline_gran)] - offset;
computeFn(forward_fn, device,
global_input_buf, global_output_buf,
(long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
(long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, output_dim,smgr);
}
cudaEventRecord(output_ready[step], torch_stream);
}
Expand All @@ -210,7 +211,7 @@ void fmoe_cuda_fused_forward_impl(
long micro_batch_size = local_expert_count[i];
computeFn(forward_fn, device,
input_buf, output_buf,
0, n_groups * num_expert + si, offset, micro_batch_size, d_model, smgr);
0, n_groups * num_expert + si, offset, micro_batch_size, d_model, output_dim,smgr);
++si;
}
}
Expand Down Expand Up @@ -271,6 +272,7 @@ void fmoe_cuda_fused_backward_impl(
const long* global_expert_count,
const bool* stored_models,
long d_model,
int output_dim,
long num_expert, long rank, long world_size,
long pipeline_gran, CudaStreamManager* smgr) {
auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
Expand Down Expand Up @@ -324,7 +326,7 @@ void fmoe_cuda_fused_backward_impl(
long micro_batch_size = local_expert_count[i];
computeFn(backward_fn, device,
grad_out, grad_in,
0, n_groups * num_expert + si, offset, micro_batch_size, d_model, smgr);
0, n_groups * num_expert + si, offset, micro_batch_size, d_model, output_dim,smgr);
collect_fn(si, i / num_expert, 0);
if (i / num_expert == rank) {
cudaEventCreate(evt_reduce + i % num_expert);
Expand All @@ -346,7 +348,7 @@ void fmoe_cuda_fused_backward_impl(

computeFn(backward_fn, device,
global_grad_out, global_grad_in,
(long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
(long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, output_dim, smgr);
}
cudaEventRecord(output_ready[step], torch_stream);
}
Expand Down
2 changes: 2 additions & 0 deletions cuda/fmoe_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ std::vector<torch::Tensor> _smart_sch_forward(
long global_batch_size,
long expert_size,
long n_workers,
int output_dim,
py::function forward_fn,
py::function get_param_fn,
py::function stash_fn,
Expand All @@ -85,6 +86,7 @@ torch::Tensor _smart_sch_backward(
long buf_batch_size,
long global_batch_size,
long n_workers,
int output_dim,
py::function backward_fn,
py::function stash_fn,
py::function pop_fn,
Expand Down
5 changes: 3 additions & 2 deletions cuda/parallel_linear.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ torch::Tensor _linear_forward(
output = torch::empty({batch_size, out_feat}, out_options);
}

AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda",
// AT_DISPATCH_CASE_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, input_buf.scalar_type(), "moe_forward_cuda",
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, input_buf.scalar_type(), "moe_forward_cuda",
([&] {
fmoe_cuda_linear_forward_impl<scalar_t>(
input_buf.data_ptr<scalar_t>(),
Expand Down Expand Up @@ -72,7 +73,7 @@ std::vector<torch::Tensor> _linear_backward(
auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});

AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, input_buf.scalar_type(), "moe_cuda_backward", ([&] {
fmoe_cuda_linear_backward_impl<scalar_t>(
grad_output_buf.data_ptr<scalar_t>(),
input_buf.data_ptr<scalar_t>(),
Expand Down
4 changes: 4 additions & 0 deletions cuda/stream_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
#define CUDA_STREAM_MANAGER_H

#include "utils/helper_cuda.h"
#include <iostream>

#ifdef FMOE_USE_NCCL
#include <nccl.h>

#define NCCL_SAFE_CALL(__fn__) { \
auto __res__ = __fn__; \
if (__res__ != ncclSuccess) { \
std::cout<< __res__ <<std::endl; \
fprintf(stderr, "NCCL Error at %s:%d value %d\n", __FILE__, __LINE__, __res__); \
fprintf(stdout, "NCCL Error at %s:%d value %d\n", __FILE__, __LINE__, __res__); \
fflush(stderr); \
exit(-1); \
} \
}
Expand Down
30 changes: 30 additions & 0 deletions cuda/utils/cublas_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define CUBLAS_WRAPPER_H
#include <cublas_v2.h>
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>

inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
cublasOperation_t transa,
Expand Down Expand Up @@ -84,6 +85,35 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
#endif
}

inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k,
const c10::BFloat16 *alpha,
const c10::BFloat16 *A, int lda,
const c10::BFloat16 *B, int ldb,
const c10::BFloat16 *beta,
c10::BFloat16 *C, int ldc) {
// return cublasHgemm(handle, transa, transb, m, n, k,
// (const __half*)alpha,
// (const __half*)A, lda,
// (const __half*)B, ldb,
// (const __half*)beta,
// (__half*)C, ldc);


return cublasGemmEx(handle, transa, transb, m, n, k,
(const void*)alpha,
(const void*)A, CUDA_R_16BF, lda,
(const void*)B, CUDA_R_16BF, ldb,
(const void*)beta,
(void*)C, CUDA_R_16BF, ldc,
CUDA_R_16BF,
CUBLAS_GEMM_DEFAULT
);
}



inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k,
Expand Down
13 changes: 8 additions & 5 deletions fmoe/fastermoe/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def forward(
num_expert,
world_size):
local_input_buf = _local_scatter(inp, pos_s)

# import pdb;pdb.set_trace()
ctx.gibs = [None] * (world_size * num_expert * 2)
ctx.gobs = [None] * (world_size * num_expert * 2)
def _expert_forward(x, y, expert_idx, store_idx):
Expand All @@ -44,6 +44,7 @@ def _expert_forward(x, y, expert_idx, store_idx):
y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64), expert_idx)
ctx.gibs[store_idx] = x
ctx.gobs[store_idx] = y0
# import pdb;pdb.set_trace()
y.copy_(y0)

ctx.experts = experts
Expand All @@ -59,12 +60,12 @@ def _expert_forward(x, y, expert_idx, store_idx):
def stash_fn(params, store_idx, expert_idx):
expert_utils.stash_expert_params(experts, params, expert_idx)
ctx.shadows[store_idx] = params

output_dim = experts[0].output_dim
local_output_buf, gib = fmoe_native.smart_sch_forward(
local_input_buf,
local_expert_count, global_expert_count,
stored_models, fwd_batch_size, ctx.expert_size,
world_size, _expert_forward, get_param_fn, stash_fn, pop_fn)
world_size, output_dim, _expert_forward, get_param_fn, stash_fn, pop_fn)

out = _local_gather(local_output_buf, pos_g, out_batch_size,
maybe_overlap=False)
Expand Down Expand Up @@ -92,6 +93,8 @@ def _expert_backward(grad_y, grad_x, expert_idx, store_idx):
grad_x.copy_(x.grad)

experts = ctx.experts
output_dim = experts[0].output_dim

def stash_fn(store_idx, expert_idx):
expert_utils.stash_expert_params(experts, ctx.shadows[store_idx], expert_idx)
pop_fn = lambda idx: expert_utils.pop_expert_params(experts, idx)
Expand All @@ -107,7 +110,7 @@ def collect_fn(store_idx, root, expert_idx):
local_expert_count, global_expert_count,
stored_models,
pos_s.shape[0], fwd_batch_size,
world_size,
world_size,output_dim,
_expert_backward, stash_fn, pop_fn, collect_fn, set_grad_fn)
grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)

Expand All @@ -127,7 +130,7 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, exp
fwd_expert_count,
fwd_batch_size,
) = prepare_forward(gate, n_expert, world_size)

# import pdb;pdb.set_trace()
global policy_fn
if policy_fn is None:
policy_fn = get_shadow_policy(d_model=inp.shape[-1])
Expand Down
1 change: 1 addition & 0 deletions fmoe/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def get_moe_group():

def count_by_gate(gate, num_expert, world_size, require_pos=True):
with torch.no_grad():
gate = gate.long()
local_expert_count = torch.zeros(
num_expert * world_size, device=gate.device, dtype=torch.int32
)
Expand Down
Loading