einops测试

news/2025/2/23 20:42:51

文章目录

1. einops

einops 主要是通过爱因斯坦标记法来处理张量矩阵的库,让矩阵处理上非常简单。

  • conda :
python">conda install conda-forge::einops

2. code

python">import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, reduce

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    x = torch.arange(96).reshape((2, 3, 4, 4)).to(torch.float32)
    print(f"x.shape={x.shape}")
    print(f"x=\n{x}")

    # 1. 转置
    x_torch_trans = x.transpose(1, 2)
    x_einops_trans = rearrange(x, 'b i w h -> b w i h')
    x_check_trans = torch.allclose(x_torch_trans, x_einops_trans)
    print(f"x_torch_trans is {x_check_trans} same with x_einops_trans")

    # 2. 变形
    x_torch_reshape = x.reshape(6, 4, 4)
    x_einops_reshape = rearrange(x, 'b i w h -> (b i) w h')
    x_check_reshape = torch.allclose(x_torch_reshape, x_einops_reshape)
    print(f"x_einops_reshape is {x_check_reshape} same with x_check_reshape")

    # 3. image2patch
    image2patch = rearrange(x, 'b i (h1 p1) (w1 p2) -> b i (h1 w1) p1 p2', p1=2, p2=2)
    print(f"image2patch.shape={image2patch.shape}")
    print(f"image2patch=\n{image2patch}")
    image2patch2 = rearrange(image2patch, 'b i j h w -> b (i j) h w')
    print(f"image2patch2.shape={image2patch2.shape}")
    print(f"image2patch2=\n{image2patch2}")
    y = torch.arange(24).reshape((2, 3, 4)).to(torch.float32)
    y_einops_mean = reduce(y, 'b h w -> b h', 'mean')
    print(f"y=\n{y}")
    print(f"y_einops_mean=\n{y_einops_mean}")
    y_tensor = torch.arange(24).reshape(2, 2, 2, 3)
    y_list = [y_tensor, y_tensor, y_tensor]
    y_output = rearrange(y_list, 'n b i h w -> n b i h w')
    print(f"y_tensor=\n{y_tensor}")
    print(f"y_output=\n{y_output}")
    z_tensor = torch.arange(12).reshape(2, 2, 3).to(torch.float32)
    z_tensor_1 = rearrange(z_tensor, 'b h w -> b h w 1')
    print(f"z_tensor=\n{z_tensor}")
    print(f"z_tensor_1=\n{z_tensor_1}")
    z_tensor_2 = repeat(z_tensor_1, 'b h w 1 -> b h w 2')
    print(f"z_tensor_2=\n{z_tensor_2}")
    z_tensor_repeat = repeat(z_tensor, 'b h w -> b (2 h) (2 w)')
    print(f"z_tensor_repeat=\n{z_tensor_repeat}")
python">x.shape=torch.Size([2, 3, 4, 4])
x=
tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.]],

         [[16., 17., 18., 19.],
          [20., 21., 22., 23.],
          [24., 25., 26., 27.],
          [28., 29., 30., 31.]],

         [[32., 33., 34., 35.],
          [36., 37., 38., 39.],
          [40., 41., 42., 43.],
          [44., 45., 46., 47.]]],


        [[[48., 49., 50., 51.],
          [52., 53., 54., 55.],
          [56., 57., 58., 59.],
          [60., 61., 62., 63.]],

         [[64., 65., 66., 67.],
          [68., 69., 70., 71.],
          [72., 73., 74., 75.],
          [76., 77., 78., 79.]],

         [[80., 81., 82., 83.],
          [84., 85., 86., 87.],
          [88., 89., 90., 91.],
          [92., 93., 94., 95.]]]])
x_torch_trans is True same with x_einops_trans
x_einops_reshape is True same with x_check_reshape
image2patch.shape=torch.Size([2, 3, 4, 2, 2])
image2patch=
tensor([[[[[ 0.,  1.],
           [ 4.,  5.]],

          [[ 2.,  3.],
           [ 6.,  7.]],

          [[ 8.,  9.],
           [12., 13.]],

          [[10., 11.],
           [14., 15.]]],


         [[[16., 17.],
           [20., 21.]],

          [[18., 19.],
           [22., 23.]],

          [[24., 25.],
           [28., 29.]],

          [[26., 27.],
           [30., 31.]]],


         [[[32., 33.],
           [36., 37.]],

          [[34., 35.],
           [38., 39.]],

          [[40., 41.],
           [44., 45.]],

          [[42., 43.],
           [46., 47.]]]],



        [[[[48., 49.],
           [52., 53.]],

          [[50., 51.],
           [54., 55.]],

          [[56., 57.],
           [60., 61.]],

          [[58., 59.],
           [62., 63.]]],


         [[[64., 65.],
           [68., 69.]],

          [[66., 67.],
           [70., 71.]],

          [[72., 73.],
           [76., 77.]],

          [[74., 75.],
           [78., 79.]]],


         [[[80., 81.],
           [84., 85.]],

          [[82., 83.],
           [86., 87.]],

          [[88., 89.],
           [92., 93.]],

          [[90., 91.],
           [94., 95.]]]]])
image2patch2.shape=torch.Size([2, 12, 2, 2])
image2patch2=
tensor([[[[ 0.,  1.],
          [ 4.,  5.]],

         [[ 2.,  3.],
          [ 6.,  7.]],

         [[ 8.,  9.],
          [12., 13.]],

         [[10., 11.],
          [14., 15.]],

         [[16., 17.],
          [20., 21.]],

         [[18., 19.],
          [22., 23.]],

         [[24., 25.],
          [28., 29.]],

         [[26., 27.],
          [30., 31.]],

         [[32., 33.],
          [36., 37.]],

         [[34., 35.],
          [38., 39.]],

         [[40., 41.],
          [44., 45.]],

         [[42., 43.],
          [46., 47.]]],


        [[[48., 49.],
          [52., 53.]],

         [[50., 51.],
          [54., 55.]],

         [[56., 57.],
          [60., 61.]],

         [[58., 59.],
          [62., 63.]],

         [[64., 65.],
          [68., 69.]],

         [[66., 67.],
          [70., 71.]],

         [[72., 73.],
          [76., 77.]],

         [[74., 75.],
          [78., 79.]],

         [[80., 81.],
          [84., 85.]],

         [[82., 83.],
          [86., 87.]],

         [[88., 89.],
          [92., 93.]],

         [[90., 91.],
          [94., 95.]]]])
y=
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])
y_einops_mean=
tensor([[ 1.500,  5.500,  9.500],
        [13.500, 17.500, 21.500]])
y_tensor=
tensor([[[[ 0,  1,  2],
          [ 3,  4,  5]],

         [[ 6,  7,  8],
          [ 9, 10, 11]]],


        [[[12, 13, 14],
          [15, 16, 17]],

         [[18, 19, 20],
          [21, 22, 23]]]])
y_output=
tensor([[[[[ 0,  1,  2],
           [ 3,  4,  5]],

          [[ 6,  7,  8],
           [ 9, 10, 11]]],


         [[[12, 13, 14],
           [15, 16, 17]],

          [[18, 19, 20],
           [21, 22, 23]]]],



        [[[[ 0,  1,  2],
           [ 3,  4,  5]],

          [[ 6,  7,  8],
           [ 9, 10, 11]]],


         [[[12, 13, 14],
           [15, 16, 17]],

          [[18, 19, 20],
           [21, 22, 23]]]],



        [[[[ 0,  1,  2],
           [ 3,  4,  5]],

          [[ 6,  7,  8],
           [ 9, 10, 11]]],


         [[[12, 13, 14],
           [15, 16, 17]],

          [[18, 19, 20],
           [21, 22, 23]]]]])
z_tensor=
tensor([[[ 0.,  1.,  2.],
         [ 3.,  4.,  5.]],

        [[ 6.,  7.,  8.],
         [ 9., 10., 11.]]])
z_tensor_1=
tensor([[[[ 0.],
          [ 1.],
          [ 2.]],

         [[ 3.],
          [ 4.],
          [ 5.]]],


        [[[ 6.],
          [ 7.],
          [ 8.]],

         [[ 9.],
          [10.],
          [11.]]]])
z_tensor_2=
tensor([[[[ 0.,  0.],
          [ 1.,  1.],
          [ 2.,  2.]],

         [[ 3.,  3.],
          [ 4.,  4.],
          [ 5.,  5.]]],


        [[[ 6.,  6.],
          [ 7.,  7.],
          [ 8.,  8.]],

         [[ 9.,  9.],
          [10., 10.],
          [11., 11.]]]])
z_tensor_repeat=
tensor([[[ 0.,  1.,  2.,  0.,  1.,  2.],
         [ 3.,  4.,  5.,  3.,  4.,  5.],
         [ 0.,  1.,  2.,  0.,  1.,  2.],
         [ 3.,  4.,  5.,  3.,  4.,  5.]],

        [[ 6.,  7.,  8.,  6.,  7.,  8.],
         [ 9., 10., 11.,  9., 10., 11.],
         [ 6.,  7.,  8.,  6.,  7.,  8.],
         [ 9., 10., 11.,  9., 10., 11.]]])

pytorch_370">3. pytorch

在这里插入图片描述


http://www.niftyadmin.cn/n/5863753.html

相关文章

【沐风老师】3DMAX快速体块生成插件QuickBlocks使用方法详解

3DMAX快速体块生成插件QuickBlocks,一键在指定区域范围内快速生成(建筑)体块工具。对于大面积的配景楼制作,这款插件是最好的选择之一。QuickBlocks使用起来快捷灵活,不仅可以自定义生成的范围,而且还可以设…

复制所绑定元素文本的vue自定义指令

最近写了一个复制所绑定元素文本的vue自定义指令,给大家分享一下。 import { ElMessage } from element-plus// data-* 属性名 const dataCopyBtnTextAttribute data-copy-btn-text // 复制按钮的class,结合项目实际进行设置 const copyBtnClass icon…

Windows和Linux下,通过C++实现获取蓝牙版本号

在 C 中获取蓝牙版本号,不同的操作系统有不同的实现方式,下面分别介绍在 Windows 和 Linux 系统下的实现方法。 Windows 系统 在 Windows 系统中,可以使用 Windows API 来与蓝牙设备交互,获取蓝牙版本号。以下是一个示例代码&…

Vite 和 Webpack 的区别和选择

简介 Nuxt3 默认使用 Vite 作为构建工具,但也可以配置为使用 Webpack。‌ 关于两者的区别和详细结构化解析可以参考文章:vite和webpack底层逻辑差异 两者实例化案例可以参考文章 : Webpack和Vite插件的开发与使用_vite使用webpack-CSDN博客 简…

《Head First设计模式》读书笔记 —— 单件模式

文章目录 为什么需要单件模式单件模式典型实现剖析定义单件模式本节用例多线程带来的问题解决问题优化 Q&A总结 《Head First设计模式》读书笔记 相关代码: Vks-Feng/HeadFirstDesignPatternNotes: Head First设计模式读书笔记及相关代码 用来创建独一无二的&a…

【MATLAB例程】RSSI/PLE定位与卡尔曼滤波NLOS抑制算法,附完整代码

本 MATLAB 代码实现了基于接收信号强度指示(RSSI)和路径损耗模型(PLE)的定位算法,并结合卡尔曼滤波技术进行非视距(NLOS)干扰抑制。通过模拟真实运动轨迹,代码展示了如何在存在NLOS干扰的情况下进行有效的定位。订阅专栏后,可阅读完整代码,可直接运行 文章目录 运行结…

Python strip() 方法详解:用途、应用场景及示例解析(中英双语)

Python strip() 方法详解:用途、应用场景及示例解析 在 Python 处理字符串时,经常会遇到字符串前后存在多余的空格或特殊字符的问题。strip() 方法就是 Python 提供的一个强大工具,专门用于去除字符串两端的指定字符。本文将详细介绍 strip(…

Ubuntu 下 nginx-1.24.0 源码分析 - ngx_array_init 函数

ngx_array_init 定义在 src/core/ngx_array.h static ngx_inline ngx_int_t ngx_array_init(ngx_array_t *array, ngx_pool_t *pool, ngx_uint_t n, size_t size) {/** set "array->nelts" before "array->elts", otherwise MSVC thinks* that "…