This repository was archived by the owner on Oct 31, 2023. It is now read-only.

Correct Integration of TT-Embedding to DLRM #10

@TimJZ

Hi, I'm currently trying to integrate TT-Embedding into the original DLRM code base. I've successfully reproduced the results shown in the README, but I'm not sure which changes are essential for the integration.

Right now I'm replacing the original EmbeddingBag (inside create_emb in dlrm_s_pytorch.py) with TTEmbeddingBag, but I'm having trouble figuring out the correct parameters. The parameters I'm currently using are:

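    from tt_embeddings_ops import TTEmbeddingBag  # module from the traceback path

    # Inside create_emb: n = rows in this table, m = embedding dimension
    EE = TTEmbeddingBag(
        n,
        m,
        tt_ranks=[12, 14],
        sparse=False,
        use_cache=False,
        weight_dist="uniform",
    )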

I left tt_p_shapes and tt_q_shapes unset, since the embedding dimension and the number of embeddings differ from table to table.
The paper mentions TT-ranks of [8, 16, 32, 64], but I couldn't use that value because it fails the assertion len(self.tt_p_shapes) <= 4. So I used the same parameters as the example ([12, 14]).
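A short sketch of the TT bookkeeping, assuming TTEmbeddingBag follows the standard tensor-train convention in which d cores are linked by d - 1 internal ranks (the boundary ranks are fixed at 1). Under that assumption, tt_ranks=[12, 14] means three cores, while a four-entry list such as [8, 16, 32, 64] implies five cores and five p-shape factors, which is consistent with the len(self.tt_p_shapes) <= 4 assertion failing. One possible reading, not confirmed here, is that the paper's [8, 16, 32, 64] lists alternative rank settings rather than a single rank vector; the loop below treats each value as a uniform rank [r, r]. The helpers tt_core_shapes and tt_param_count are hypothetical, and the shapes p = [100, 100, 100] (n = 10^6) and q = [4, 4, 4] (m = 64) are illustrative only.

```python
from math import prod


def tt_core_shapes(p_shapes, q_shapes, tt_ranks):
    """Logical TT-core shapes (r_{k-1}, p_k, q_k, r_k), with r_0 = r_d = 1.

    Assumes the standard tensor-train convention: d cores are linked by
    d - 1 internal ranks, so len(tt_ranks) must equal len(p_shapes) - 1.
    """
    d = len(p_shapes)
    if len(q_shapes) != d:
        raise ValueError("p_shapes and q_shapes must have the same length")
    if len(tt_ranks) != d - 1:
        raise ValueError(f"{d} cores need {d - 1} internal ranks, got {len(tt_ranks)}")
    ranks = [1, *tt_ranks, 1]  # pad with the fixed boundary ranks
    return [(ranks[k], p_shapes[k], q_shapes[k], ranks[k + 1]) for k in range(d)]


def tt_param_count(p_shapes, q_shapes, tt_ranks):
    """Total number of TT-core parameters (sum of the core sizes)."""
    return sum(prod(shape) for shape in tt_core_shapes(p_shapes, q_shapes, tt_ranks))


if __name__ == "__main__":
    # Illustrative table: n = 100 * 100 * 100 = 10**6 rows, m = 4 * 4 * 4 = 64 columns.
    p, q = [100, 100, 100], [4, 4, 4]
    print(tt_core_shapes(p, q, [12, 14]))
    # [(1, 100, 4, 12), (12, 100, 4, 14), (14, 100, 4, 1)]
    print(tt_param_count(p, q, [12, 14]), "vs", 10**6 * 64)  # 77600 vs 64000000
    # Each value in [8, 16, 32, 64] as a uniform rank [r, r] on three cores;
    # prints 8 32000, 16 115200, 32 435200, 64 1689600.
    for r in (8, 16, 32, 64):
        print(r, tt_param_count(p, q, [r, r]))
```

In the usual TT-embedding setup, the product of the p-shapes has to cover a table's n and the product of the q-shapes has to equal m, so one fixed pair of shapes cannot serve every DLRM table; each table needs its own factorization.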

This results in a CUDA illegal memory access error at line 174 of tt_embeddings_ops.py. The full error message is below:

Traceback (most recent call last):
  File "dlrm_s_pytorch.py", line 1013, in <module>
    Z = dlrm_wrap(X, lS_o, lS_i, use_gpu, device)
  File "dlrm_s_pytorch.py", line 866, in dlrm_wrap
    return dlrm(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "dlrm_s_pytorch.py", line 385, in forward
    return self.parallel_forward(dense_x, lS_o, lS_i)
  File "dlrm_s_pytorch.py", line 470, in parallel_forward
    ly = self.apply_emb(lS_o, lS_i, self.emb_l)
  File "dlrm_s_pytorch.py", line 328, in apply_emb
    V = E(sparse_index_group_batch, sparse_offset_group_batch)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/mnt/dlrm/tt_embeddings_ops.py", line 801, in forward
    output = TTLookupFunction.apply(
  File "/mnt/dlrm/tt_embeddings_ops.py", line 174, in forward
    output = tt_embeddings.tt_forward(
RuntimeError: CUDA error: an illegal memory access was encountered
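One way to narrow this down, sketched here as a debugging aid rather than anything from the DLRM or TT-Embedding code: an illegal memory access in an embedding kernel can come from lookup indices outside [0, n), malformed offsets, or inputs sitting on a different GPU than the table. The hypothetical helper check_bag_inputs checks those conditions in Python before the kernel runs. Since CUDA reports errors asynchronously, running with CUDA_LAUNCH_BLOCKING=1 also helps attribute the failure to the kernel that actually faulted.

```python
import torch


def check_bag_inputs(indices: torch.Tensor, offsets: torch.Tensor,
                     num_embeddings: int, device: torch.device) -> None:
    """Hypothetical pre-lookup check: raise a readable ValueError instead of
    letting malformed inputs reach a CUDA embedding kernel."""
    # Inputs should live on the same device as the table's weights.
    for name, t in (("indices", indices), ("offsets", offsets)):
        if t.device != device:
            raise ValueError(f"{name} on {t.device}, table weights on {device}")
    # Every index must address an existing row: 0 <= i < num_embeddings.
    if indices.numel() > 0:
        lo, hi = int(indices.min()), int(indices.max())
        if lo < 0 or hi >= num_embeddings:
            raise ValueError(f"indices span [{lo}, {hi}], table has {num_embeddings} rows")
    # Offsets are bag start positions: start at 0, never decrease, stay in range.
    if offsets.numel() > 0:
        if int(offsets[0]) != 0:
            raise ValueError("offsets[0] must be 0")
        if bool((offsets[1:] < offsets[:-1]).any()):
            raise ValueError("offsets must be non-decreasing")
        if int(offsets[-1]) > indices.numel():
            raise ValueError("last offset exceeds the number of indices")
```

A hypothetical call site is apply_emb, just before V = E(sparse_index_group_batch, sparse_offset_group_batch), passing that table's row count n and the device of E's parameters (for example next(E.parameters()).device).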

1. I suspect this is caused by incorrect parameters and would appreciate any help here.
2. I'm also wondering whether any other changes need to be made to DLRM besides replacing the EmbeddingBag.
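One DLRM-side detail that may be worth checking for question 2: DLRM's lS_o holds one start position per sample (length B), which is the default nn.EmbeddingBag convention. PyTorch's nn.EmbeddingBag also accepts a CSR-style convention of length B + 1, whose last entry is the total number of indices, via include_last_offset=True. Whether TTEmbeddingBag expects the B + 1 form is an assumption to verify in tt_embeddings_ops.py; if it does, the offsets could be extended as in this sketch, where to_csr_offsets is a hypothetical helper. The self-check uses nn.EmbeddingBag to confirm that both conventions describe the same bags.

```python
import torch
import torch.nn as nn


def to_csr_offsets(indices: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: turn length-B start offsets into length-(B + 1)
    CSR offsets, so bag b spans indices[csr[b]:csr[b + 1]]."""
    total = torch.tensor([indices.numel()], dtype=offsets.dtype, device=offsets.device)
    return torch.cat([offsets, total])


if __name__ == "__main__":
    # Three bags: [3, 7], [1], [4, 9]
    indices = torch.tensor([3, 7, 1, 4, 9])
    offsets = torch.tensor([0, 2, 3])
    csr = to_csr_offsets(indices, offsets)
    print(csr.tolist())  # [0, 2, 3, 5]

    # Sanity check with PyTorch's own EmbeddingBag: both conventions give the same sums.
    weight = torch.randn(10, 4)
    bag_default = nn.EmbeddingBag.from_pretrained(weight, mode="sum")
    bag_csr = nn.EmbeddingBag.from_pretrained(weight, mode="sum", include_last_offset=True)
    assert torch.allclose(bag_default(indices, offsets), bag_csr(indices, csr))
```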

Thanks!
