Shouldn't device type in model.py methods be torch.device instead of str? #246

@igodlab

The from_pretrained and from_local class methods in zonos/model.py take a device: str = DEFAULT_DEVICE argument, type-hinted as str. However, get_device and DEFAULT_DEVICE in zonos/utils.py show that the default value is actually a torch.device:

def get_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device(torch.cuda.current_device())
    # MPS breaks for whatever reason. Uncomment when it's working.
    # if torch.mps.is_available():
    #     return torch.device("mps")
    return torch.device("cpu")


DEFAULT_DEVICE = get_device()

Is the type hint of the methods' argument wrong?
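
If it is, one possible fix is to widen the hint so it covers both strings and torch.device objects. Below is a minimal sketch of that idea, not the actual zonos code: DeviceLike and normalize_device are hypothetical names I made up for illustration, and I'm assuming the methods only need something they can turn into a torch.device.

```python
from typing import Union

import torch

# Hypothetical alias (not part of zonos): either a device string or a torch.device.
DeviceLike = Union[str, torch.device]


def normalize_device(device: DeviceLike) -> torch.device:
    """Return a torch.device for either a device string or a torch.device."""
    if isinstance(device, torch.device):
        # Already the right type, so pass it through unchanged.
        return device
    # Strings such as "cpu" or "cuda:0" are parsed by the torch.device constructor.
    return torch.device(device)


# Both call styles end up with the same type.
cpu_from_str = normalize_device("cpu")
cpu_from_obj = normalize_device(torch.device("cpu"))
print(type(cpu_from_str), type(cpu_from_obj))
```

With a hint like this, the signature could read device: DeviceLike = DEFAULT_DEVICE, which would match the torch.device default while still accepting strings from existing callers.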
