1 | # PyTorch 张量基础 |
张量属性
1 | print(f"Shape of tensor: {rand_tensor.shape}") # 形状 |
索引和切片
1 | tensor = torch.ones(4, 4) |
张量连接
1 | t1 = torch.cat([tensor, tensor, tensor], dim=1) # 沿维度1连接 |
张量运算
元素级乘法
1 | # 两种等价写法 |
矩阵乘法
1 | # 两种等价写法 |
原地操作(In-place)
1 | tensor.add_(5) # 带_后缀的操作会修改原张量 |
注意:原地操作可以节省内存,但在计算导数时可能会出现问题,因此不鼓励使用。
与NumPy互转
张量转NumPy数组
1 | t = torch.ones(5) |
NumPy数组转张量
1 | n = np.ones(5) |