-
Notifications
You must be signed in to change notification settings - Fork 45
Open
Description
New Issue Checklist
- [×] I have read the Contribution Guidelines
- [×] I searched for existing GitHub issues
Issue Description
Pytorch后端模型定义NHWC和NCHW数据格式主要是定以数据和模型后传到设备时用.to("cuda:0", memory_format=torch.channels_last)确定。
TLX目前做法是pytorch依据nhwc格式时,全部转NCHW然后处理完转回来,这潜在是让模型用NCHW格式计算。对纯GPU应用时问题不大,但是对于一些NHWC友好的设备部署,比如未来的Mindspore,由于多次nhwc nchw切换,性能有损失。
这里可能需要框架对于pytorch这里nhwc支持改成全局变量,即输入时数据做nchw-nhwc,模型转nhwc然后计算即可。
不过Pytorch本身GPU NHWC支持稀烂,倒不是很急。
Metadata
Metadata
Assignees
Labels
No labels