Skip to content

RuntimeError: Tensor on device cuda:0 is not on the expected device meta! #737

@LucasMagnana

Description

@LucasMagnana

🐛 Bug

Hello, I am trying to train LLMs on a language-modelling task with differential privacy using Opacus. My code works with `gpt2`, but with `bert-base-cased` it throws `RuntimeError: Tensor on device cuda:0 is not on the expected device meta!`.

To Reproduce

The code I use is below; the model is an AutoModelForLanguageModelling from the transformers library:

def train(self, model, lr, train_dataset, eval_dataset, num_epochs):
    """Fine-tune ``model`` on ``train_dataset`` with differential privacy via Opacus.

    The model, optimizer and data loader are wrapped by
    ``PrivacyEngine.make_private_with_epsilon`` with target_epsilon=7.5,
    target_delta=1/len(train_dataset) and max_grad_norm=0.1.
    ``BatchMemoryManager`` caps physical batches at 4 examples. The logical
    batch size stays ``self.config.train_batch_size``.

    Args:
        model: A transformers language-modelling model. It is assumed to return
            the loss as its first output when ``labels`` are passed (as the
            original inline comment states) — confirm for non-HF models.
        lr: AdamW learning rate.
        train_dataset: Sized training dataset, collated by ``self.data_collator``.
            The collated batches must contain a ``"labels"`` entry.
        eval_dataset: Currently unused; kept for interface compatibility.
        num_epochs: Number of passes over ``train_dataset``.

    Returns:
        The model object returned by ``make_private_with_epsilon`` after training.

    Raises:
        ValueError: If a collated batch has no ``"labels"`` entry.
    """
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=self.config.train_batch_size,
        collate_fn=self.data_collator,
    )
    model = model.to(self.device)
    # Set the model to train mode (HuggingFace models load in eval mode)
    model = model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # delta bounds the probability that the privacy guarantee fails and should be
    # below 1/N, where N is the number of training *examples*.
    # len(train_dataloader) counts batches, which would make delta roughly
    # batch_size times too large.
    delta = 1 / len(train_dataset)

    privacy_engine = PrivacyEngine()
    model, optimizer, train_dataloader = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimizer,
        data_loader=train_dataloader,
        target_delta=delta,
        target_epsilon=7.5,
        epochs=num_epochs,
        max_grad_norm=0.1,
    )

    for epoch in range(1, num_epochs + 1):
        losses = []

        with BatchMemoryManager(
            data_loader=train_dataloader,
            max_physical_batch_size=4,
            optimizer=optimizer,
        ) as memory_safe_data_loader:
            for step, batch in enumerate(tqdm(memory_safe_data_loader)):
                optimizer.zero_grad()

                # Labels must be forwarded. Without them the model returns logits
                # (not a loss) as outputs[0], and training would minimise the mean
                # logit instead of the language-modelling loss.
                if "labels" not in batch:
                    raise ValueError(
                        "collated batch has no 'labels'; the model would return "
                        "logits instead of a loss"
                    )
                inputs = {k: batch[k].to(self.device) for k in batch}

                # With labels given: outputs = loss, logits, hidden_states, attentions
                outputs = model(**inputs)

                # .mean() is a no-op on a scalar loss; it is kept in case the loss
                # arrives with extra dimensions (e.g. one value per device).
                loss = outputs[0].mean()

                # NOTE(review): with bert-base-cased this backward call is where the
                # reported "Tensor on device cuda:0 is not on the expected device
                # meta!" error is raised. The error occurs inside Opacus's functorch
                # per-sample-gradient path while it re-runs a modeling_bert
                # `decoder` Linear. It is presumably related to parameters shared
                # between modules (tied weights) — TODO confirm. It is not worked
                # around here.
                loss.backward()
                losses.append(loss.item())

                optimizer.step()

                if step > 0 and step % 5000 == 0:
                    train_loss = np.mean(losses)
                    eps = privacy_engine.get_epsilon(delta)

                    print(
                        f"Epoch: {epoch} | "
                        f"Step: {step} | "
                        f"Train loss: {train_loss:.3f} | "
                        f"ɛ: {eps:.2f}"
                    )

    # The caller assigns the result (model = self.train(...)), so return it.
    return model

The full error:

Traceback (most recent call last):
  File "/home/lmagnana/nlp-attacks/examples/special_finetunings/n2c2_ner_mlm_finetuning.py", line 72, in <module>
    models, metrics = finetuner.run(dataset, test_size, epochs, pathlib.Path(output_dir), output_name=ouput_name)
  File "/home/lmagnana/nlp-attacks/nlp_attacks/finetuners/Finetuner.py", line 248, in run
    model = self.train(model, self.config.learning_rate, ds["train"], ds["test"], epochs) 
  File "/home/lmagnana/nlp-attacks/nlp_attacks/finetuners/PrivacyPreservingLanguageModelling.py", line 141, in train
    loss.backward()
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_tensor.py", line 525, in backward
    torch.autograd.backward(
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 72, in __call__
    return self.hook(module, *args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/grad_sample_module.py", line 340, in capture_backprops_hook
    grad_samples = grad_sampler_fn(module, activations, backprops)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/functorch.py", line 108, in ft_compute_per_sample_gradient
    per_sample_grads = layer.ft_compute_sample_grad(
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 281, in vmap_impl
    return _flat_vmap(
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 47, in fn
    return f(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 403, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/apis.py", line 363, in wrapper
    return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 1285, in grad_impl
    results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 47, in fn
    return f(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 1249, in grad_and_value_impl
    output = func(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/functorch.py", line 85, in compute_loss_stateless_model
    output = flayer(params, batched_activations)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/functorch.py", line 50, in fmodel
    return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/functional_call.py", line 143, in functional_call
    return nn.utils.stateless._functional_call(
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/utils/stateless.py", line 263, in _functional_call
    return module(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 787, in forward
    hidden_states = self.decoder(hidden_states)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims_common/wrappers.py", line 252, in _fn
    result = fn(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims_common/wrappers.py", line 137, in _fn
    result = fn(**bound.arguments)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_refs/__init__.py", line 1091, in add
    output = prims.add(a, b)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_ops.py", line 594, in __call__
    return self_._op(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims/__init__.py", line 359, in _prim_elementwise_meta
    utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims_common/__init__.py", line 740, in check_same_device
    raise RuntimeError(msg)
RuntimeError: Tensor on device cuda:0 is not on the expected device meta!

Expected behavior

The code should work with both a gpt2 and a bert-base-cased model.

Environment

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==8.9.2.26
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.5.82
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] torch==2.3.1
[pip3] opacus==1.5.3
[pip3] triton==2.3.1
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         8.9.2.26                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
[conda] nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.5.82                  pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
[conda] torch                     2.3.1                    pypi_0    pypi
[conda] triton                    2.3.1                    pypi_0    pypi

Thanks in advance for your replies.

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions