pytorch 添加一个维度
使用 unsqueeze 函数即可
import torch
x = torch.tensor([1, 2, 3, 4])
# 在外面增加一个维度
x.unsqueeze(0)
输出:
tensor([[1, 2, 3, 4]])
如果给里面的每个元素都扩展一维:
x.unsqueeze(1)
输出:
tensor([[1],
        [2],
        [3],
        [4]])
使用 unsqueeze 函数即可
import torch
x = torch.tensor([1, 2, 3, 4])
# 在外面增加一个维度
x.unsqueeze(0)
输出:
tensor([[1, 2, 3, 4]])
如果给里面的每个元素都扩展一维:
x.unsqueeze(1)
输出:
tensor([[1],
        [2],
        [3],
        [4]])