
model precision #69

@Fans0014

I'm trying to run the model at https://docs-assets.developer.apple.com/ml-research/models/mdm/flickr1024/vis_model.pth in torch.float16 precision, but the output is NaN. While debugging the inference process, I found that some activation values in the model exceed the maximum value representable in torch.float16.
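The following is a minimal sketch of why an overflow surfaces as NaN rather than as a large number. It assumes a recent PyTorch build and runs on CPU, and the tensor values are illustrative rather than taken from the model. Any value above 65504 becomes inf when cast to float16. Arithmetic between infinities, such as subtracting a mean, then yields NaN. Even a value that fits, such as 1961.69, overflows once it is squared, as happens inside a variance computation.

```python
import torch

# Largest finite float16 value: 65504.0
FP16_MAX = torch.finfo(torch.float16).max

# Illustrative float32 activations; the first is far outside the float16 range.
acts32 = torch.tensor([1145381.0, 142.4, -0.03])

# The cast silently maps out-of-range values to +/-inf (no error is raised).
acts16 = acts32.to(torch.float16)

# inf - inf is undefined, so centering against an infinite value yields NaN.
centered = acts16 - acts16[0]

# A value that fits in float16 can still overflow in an intermediate step:
# 1961.69 ** 2 is about 3.85e6, well above FP16_MAX.
t16 = torch.tensor([1961.69], dtype=torch.float16)
squared = t16 * t16

print(FP16_MAX)                          # 65504.0
print(acts16[0].item())                  # inf
print(torch.isnan(centered[0]).item())   # True
print(squared.item())                    # inf
```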

Running in torch.float32, I inserted the following print statement at https://github.com/apple/ml-mdm/blob/main/ml-mdm-matryoshka/ml_mdm/models/unet.py#L543:

    print(">>>Debug unet-620: ", x.mean(), temb.mean())
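Editing unet.py for each location is one approach. A forward hook can instead record the peak activation magnitude of every submodule in a single float32 pass. Peak magnitude is a more direct overflow signal than the mean, because a mean can hide outliers of opposite sign. The sketch below uses the standard nn.Module.register_forward_hook API. The two-layer model is a hypothetical stand-in for the UNet, and loading vis_model.pth is not shown. attach_range_hooks is a helper name invented for this example. Submodules that return tuples or dicts are simply skipped.

```python
import torch
import torch.nn as nn

FP16_MAX = torch.finfo(torch.float16).max

def attach_range_hooks(model: nn.Module, stats: dict):
    """Hypothetical helper: record the max |output| of every submodule."""
    handles = []
    for name, module in model.named_modules():
        if name == "":
            continue  # skip the root module itself
        def hook(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor):  # ignore tuple/dict outputs
                peak = output.detach().abs().max().item()
                stats[name] = max(stats.get(name, 0.0), peak)
        handles.append(module.register_forward_hook(hook))
    return handles

# Toy stand-in for the UNet (hypothetical): the second layer amplifies activations.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
with torch.no_grad():
    model[0].weight.fill_(1.0)
    model[0].bias.zero_()
    model[1].weight.fill_(1e5)
    model[1].bias.zero_()

stats = {}
handles = attach_range_hooks(model, stats)
with torch.no_grad():
    model(torch.ones(1, 4))
for h in handles:
    h.remove()  # detach hooks once the measurement pass is done

overflowing = [n for n, peak in stats.items() if peak > FP16_MAX]
print(overflowing)  # ['1']
```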

The print statement produced the following log:

Debug unet-620: tensor(-0.0274, device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(1145381., device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(142.4138, device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(-22.8761, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-3.5595, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-0.2938, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-0.3821, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.0356, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.1432, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.1935, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.7042, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.0997, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.4697, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-1.0218, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-6.7193, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-9.7318, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-3.2291, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-137.9206, device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(-129216.2422, device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(-672989.1250, device='cuda:0') tensor(1961.6932, device='cuda:0')
====================================
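The log records means, so a mean above 65504 in magnitude already proves that at least one element overflows float16, because |mean| is never larger than the largest element magnitude. A small mean does not rule overflow out. The pure-Python sketch below applies this check to a few lines copied from the log above. The regex assumes the `tensor(<value>, device=...)` format shown there, and proven_overflows is a hypothetical helper name.

```python
import re

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value

# Lines copied from the float32 log above (x mean first, temb mean second).
LOG = """\
Debug unet-620: tensor(-0.0274, device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(1145381., device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(142.4138, device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(-129216.2422, device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(-672989.1250, device='cuda:0') tensor(1961.6932, device='cuda:0')
"""

# Captures the number right after "tensor(", e.g. "1145381." or "-0.0274".
VALUE = re.compile(r"tensor\((-?[\d.]+)")

def proven_overflows(log_text):
    """Return (line_index, x_mean) for lines whose |x mean| exceeds FP16_MAX.

    Since |mean| <= max|x|, these calls certainly contain values float16 cannot hold.
    """
    hits = []
    for i, line in enumerate(log_text.splitlines()):
        x_mean, _temb_mean = (float(v) for v in VALUE.findall(line))
        if abs(x_mean) > FP16_MAX:
            hits.append((i, x_mean))
    return hits

print(proven_overflows(LOG))
# [(1, 1145381.0), (3, -129216.2422), (4, -672989.125)]
```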

Clearly, some values exceed 65504, the maximum finite value of torch.float16. Is there a way to fine-tune this model to reduce its intermediate activation values so that it can run in torch.float16? Alternatively, could you provide a new model whose activations stay within the torch.float16 range?
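One option not raised in the issue is to run in torch.bfloat16 instead of torch.float16, if the GPU supports it. This may sidestep the overflow, because bfloat16 keeps an 8-bit exponent (max about 3.39e38, close to float32's range) at the cost of fewer mantissa bits. The sketch below only compares how the two 16-bit formats store the means logged above. Whether the model's output quality holds up in bfloat16 is an untested assumption.

```python
import torch

# Means taken from the float32 log above.
logged = torch.tensor([1145381.0, -129216.2422, -672989.1250, 1961.6932])

as_fp16 = logged.to(torch.float16)   # 5-bit exponent: max 65504
as_bf16 = logged.to(torch.bfloat16)  # 8-bit exponent: roughly float32's range

print(torch.isinf(as_fp16).tolist())         # [True, True, True, False]
print(torch.isfinite(as_bf16).all().item())  # True

# bfloat16 trades precision for range: it keeps 8 significant bits,
# so the relative rounding error stays below 2**-8.
rel_err = ((as_bf16.float() - logged).abs() / logged.abs()).max().item()
print(rel_err < 2 ** -8)  # True
```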
