Skip to content

Pytorch后端NHWC和NCHW问题 #69

@Windaway

Description

@Windaway

New Issue Checklist

Issue Description

Pytorch后端模型定义NHWC和NCHW数据格式主要是定以数据和模型后传到设备时用.to("cuda:0", memory_format=torch.channels_last)确定。

TLX目前做法是pytorch依据nhwc格式时,全部转NCHW然后处理完转回来,这潜在是让模型用NCHW格式计算。对纯GPU应用时问题不大,但是对于一些NHWC友好的设备部署,比如未来的Mindspore,由于多次nhwc nchw切换,性能有损失。

这里可能需要框架对于pytorch这里nhwc支持改成全局变量,即输入时数据做nchw-nhwc,模型转nhwc然后计算即可。

不过Pytorch本身GPU NHWC支持稀烂,倒不是很急。

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions