使用 unsqueeze
函数即可
import torch
x = torch.tensor([1, 2, 3, 4])
# 在外面增加一个维度
x.unsqueeze(0)
输出:
tensor([[1, 2, 3, 4]])
如果给里面的每个元素都扩展一维:
x.unsqueeze(1)
输出:
tensor([[1],
[2],
[3],
[4]])
使用 unsqueeze
函数即可
import torch
x = torch.tensor([1, 2, 3, 4])
# 在外面增加一个维度
x.unsqueeze(0)
输出:
tensor([[1, 2, 3, 4]])
如果给里面的每个元素都扩展一维:
x.unsqueeze(1)
输出:
tensor([[1],
[2],
[3],
[4]])