Update Habana environment to 1.20.0 and PyTorch 2.6.0#406
Update Habana environment to 1.20.0 and PyTorch 2.6.0#406jmduarte wants to merge 19 commits intojpata:mainfrom
Conversation
|
So habana software requires Now I get past that issue, but I notice when I try to covert the dataset with |
|
The tfds documentation says that "TensorFlow is no longer a dependency to read datasets": https://www.tensorflow.org/datasets/tfless_tfds The CI job also tries to create the dataset, so potentially we'd want to skip the CI for the habana branch. In any case, to run the ML training, you would use a pre-existing dataset. |
|
OK thanks! I managed to install a compatible tf, but i agree it's not necessary since I can probably read a pre-made dataset. I now see another minor error during validation/plotting, but the training runs (on CPU). Will try to get the training to run on HPU now. |
|
made it to this error, but I actually don't see where it's coming from at the moment. It seems Traceback (most recent call last):
File "/particleflow/mlpf/pipeline.py", line 160, in <module>
main()
File "/particleflow/mlpf/pipeline.py", line 156, in main
device_agnostic_run(config, world_size, experiment_dir, args.habana)
File "/particleflow/mlpf/model/training.py", line 855, in device_agnostic_run
run(rank, world_size, config, outdir, logfile)
File "/particleflow/mlpf/model/training.py", line 740, in run
train_all_epochs(
File "/particleflow/mlpf/model/training.py", line 369, in train_all_epochs
losses_train = train_epoch(
File "/particleflow/mlpf/model/training.py", line 144, in train_epoch
loss_opt, loss, _, _, _ = model_step(batch, model, mlpf_loss)
File "/particleflow/mlpf/model/training.py", line 75, in model_step
loss_opt, losses_detached = loss_fn(ytarget, ypred, batch)
File "/particleflow/mlpf/model/losses.py", line 115, in mlpf_loss
was_input_true = torch.concat([torch.nn.functional.one_hot((y["cls_id"] != 0).to(torch.long)), y["momentum"]], axis=-1) * batch.mask.unsqueeze(
RuntimeError: Number of classes cannot be -1 |
Supersedes #309