This repository was archived by the owner on May 29, 2026. It is now read-only.

Issue with use_fp16=True Leading to Type Conversion Error in unet.py #56

Description

@DuoLi1999

With use_fp16=False, the code works correctly. With use_fp16=True, however, it fails because of an unexpected type conversion in unet.py (line 435).

The problem occurs at line 435, where the tensor a is converted from float16 to float32:

a = a.float()

Before this line, a is float16; after it, a is float32. However, removing or commenting out the line also causes an error. It seems that keeping a in float16 is necessary for use_fp16=True to work, but the current implementation converts it to float32, which leads to the failure.
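A common pattern in mixed-precision attention is to compute the softmax in float32 for numerical stability and then cast the weights back to the dtype of the value tensor, so that the following matmul sees matching dtypes. The sketch below illustrates that pattern using NumPy as a stand-in for PyTorch; the function names `softmax_fp32` and `attention` and the (length, channels) layout are hypothetical and are not taken from unet.py. One difference to keep in mind: NumPy silently promotes a float32 @ float16 product to float32, whereas PyTorch matmul-style ops such as `torch.bmm` reject mixed dtypes with an error. That is one way a stray float32 tensor can surface as a failure under use_fp16=True.

```python
import math

import numpy as np


def softmax_fp32(x):
    """Numerically stable softmax over the last axis, computed in float32."""
    x = x.astype(np.float32)  # analogous to `a = a.float()` in the issue
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def attention(q, k, v, cast_back=True):
    """Single-head attention over (length, channels) arrays (hypothetical layout)."""
    scale = 1.0 / math.sqrt(q.shape[-1])  # Python float, so the product keeps q's dtype
    weight = (q @ k.T) * scale  # float16 when q and k are float16
    a = softmax_fp32(weight)  # float32 at this point
    if cast_back:
        a = a.astype(v.dtype)  # back to float16 before combining with v
    return a @ v


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((4, 8)).astype(np.float16) for _ in range(3))

out = attention(q, k, v)
print(out.dtype)  # float16
print(attention(q, k, v, cast_back=False).dtype)  # float32: NumPy promotes silently
```

Casting back with `v.dtype` rather than a hard-coded float16 keeps the same code path valid when use_fp16=False.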

Additionally, I noticed that the current code has been modified so that flash attention is no longer used. I also tried running the original version but encountered similar errors.
