I. Note
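First, keep in mind that broadcasting applies only when two ndarrays take part in an element-wise operation (element-wise addition, subtraction, and so on), not matrix multiplication. Matrix multiplication requires the dimensions to match strictly: for a 2-D np.dot(A, B), the number of columns of A must equal the number of rows of B.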
Example (matrix multiplication):
import numpy as np
A = np.zeros((2,4))
B = np.zeros((3,4))
np.dot(A,B)
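The error reports that the shapes are not aligned:
ValueError: shapes (2,4) and (3,4) not aligned: 4 (dim 1) != 3 (dim 0)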
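For contrast, a minimal sketch of how the alignment error goes away: np.dot needs the inner dimensions to agree, so transposing B to shape (4, 3) makes the product valid. Whether B.T is what you actually meant is an assumption about the caller's intent, not something broadcasting can decide for you.

```python
import numpy as np

A = np.zeros((2, 4))
B = np.zeros((3, 4))

# np.dot on 2-D arrays needs A.shape[1] == B.shape[0]; here 4 != 3, so it fails
dot_error = ""
try:
    np.dot(A, B)
except ValueError as err:
    dot_error = str(err)
    print("np.dot failed:", dot_error)

# Transposing B gives shape (4, 3), so the inner dimensions (4 and 4) now agree
product = np.dot(A, B.T)
print(product.shape)  # (2, 3)
```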
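If the error message mentions broadcasting instead, the two ndarrays were used in an element-wise operation (element-wise addition, subtraction, and so on) with incompatible shapes. Example: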
import numpy as np
A = np.zeros((2,4))
B = np.zeros((3,4))
C = A*B
The error is:
ValueError: operands could not be broadcast together with shapes (2,4) (3,4)
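To check compatibility before running an element-wise operation, np.broadcast_shapes (available since NumPy 1.20) raises the same kind of ValueError without allocating any arrays. The wrapper can_broadcast below is a hypothetical convenience helper, not part of NumPy.

```python
import numpy as np

def can_broadcast(*shapes):
    """Hypothetical helper: True if NumPy can broadcast the given shapes together."""
    try:
        np.broadcast_shapes(*shapes)
        return True
    except ValueError:
        return False

print(can_broadcast((2, 4), (3, 4)))  # False: 2 vs 3 on the first axis
print(can_broadcast((2, 4), (1, 4)))  # True: the result shape is (2, 4)
```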
II. A simple broadcasting example
As a simple example, multiply every element of a 1-D array by 2:
>>> a = np.array([1., 2., 3.])
>>> b = np.array([2., 2., 2.])
>>> a*b
array([2., 4., 6.])
With broadcasting, b can simply be a scalar:
>>> a = np.array([1., 2., 3.])
>>> b = 2.
>>> a*b
array([2., 4., 6.])
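A short sketch confirming that the two versions agree: the scalar behaves as if it were stretched to a's shape, which np.broadcast_to makes visible.

```python
import numpy as np

a = np.array([1., 2., 3.])

explicit = a * np.array([2., 2., 2.])  # b spelled out element by element
broadcast = a * 2.                     # scalar b is broadcast to shape (3,)

# The scalar acts as if it were stretched to a's shape
stretched = np.broadcast_to(2., a.shape)
print(stretched)                            # [2. 2. 2.]
print(np.array_equal(explicit, broadcast))  # True
```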
III. Broadcasting rules
- All input arrays with ndim smaller than the input array of largest ndim, have 1’s prepended to their shapes.
- The size in each dimension of the output shape is the maximum of all the input sizes in that dimension.
- An input can be used in the calculation if its size in a particular dimension either matches the output size in that dimension, or has value exactly 1.
- If an input has a dimension size of 1 in its shape, the first data entry in that dimension will be used for all calculations along that dimension. In other words, the stepping machinery of the ufunc will simply not step along that dimension (the stride will be 0 for that dimension).
Put simply:
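- All input arrays are aligned to the one with the most dimensions; shorter shapes are padded with 1s on the left.
- Each axis of the output shape is the maximum of the input sizes along that axis.
- An input array can take part in the computation only if, on every axis, its length either equals the output length or is 1; otherwise an error is raised.
- When an input array has length 1 along an axis, its single entry along that axis is reused for every computation along that axis.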
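The last rule (stride 0 along a length-1 axis) can be observed directly. This sketch uses np.broadcast_to, which returns a read-only view; the float64 dtype is chosen so the item size is 8 bytes.

```python
import numpy as np

row = np.array([10., 20., 30.])      # shape (3,), float64 -> 8-byte items
view = np.broadcast_to(row, (4, 3))  # padded to (1, 3), then stretched to (4, 3)

print(view.shape)                   # (4, 3)
print(view.strides)                 # (0, 8): stepping along axis 0 does not move in memory
print(np.shares_memory(view, row))  # True: no data was copied
```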
A more general way to state the broadcasting rules:
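When operating on two arrays, NumPy compares their shapes (tuples) dimension by dimension, starting from the trailing (rightmost) dimension. Two dimensions are compatible only when:
- they are equal, or
- one of them is 1 (that array is then stretched along this dimension so the shapes match)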
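The two conditions above translate almost line for line into code. The function broadcast_shape below is a hypothetical pure-Python sketch for two shapes, written for illustration; in practice np.broadcast_shapes does this for you.

```python
from itertools import zip_longest

def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: result shape of broadcasting two shapes."""
    result = []
    # Walk from the trailing dimension forwards; missing leading dims count as 1
    for da, db in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if da == db or db == 1:
            result.append(da)      # equal, or b is stretched to a
        elif da == 1:
            result.append(db)      # a is stretched to b
        else:
            raise ValueError(f"incompatible dimensions {da} and {db}")
    return tuple(reversed(result))

print(broadcast_shape((8, 1, 6, 1), (7, 1, 5)))  # (8, 7, 6, 5)
```

Picking the non-1 size, rather than taking max(da, db), keeps zero-length axes correct: a length-0 axis broadcast against a length-1 axis stays length 0.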
The examples below illustrate the four rules above (they are taken from the blog post "numpy 中的 broadcasting(广播)机制" on blog.csdn.net):
Examples:
Image (3d array): 256 x 256 x 3
Scale (1d array): 3
Result (3d array): 256 x 256 x 3

A (4d array): 8 x 1 x 6 x 1
B (3d array): 7 x 1 x 5
Result (4d array): 8 x 7 x 6 x 5

A (2d array): 5 x 4
B (1d array): 1
Result (2d array): 5 x 4

A (3d array): 15 x 3 x 5
B (3d array): 15 x 1 x 5
Result (3d array): 15 x 3 x 5
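The compatible cases above can be checked mechanically. This sketch asks np.broadcast_shapes for each result and also performs a real addition on zero-filled arrays to confirm the shape.

```python
import numpy as np

# (shape_a, shape_b, expected result) for the compatible cases listed above
cases = [
    ((256, 256, 3), (3,), (256, 256, 3)),
    ((8, 1, 6, 1), (7, 1, 5), (8, 7, 6, 5)),
    ((5, 4), (1,), (5, 4)),
    ((15, 3, 5), (15, 1, 5), (15, 3, 5)),
]

for shape_a, shape_b, expected in cases:
    got = np.broadcast_shapes(shape_a, shape_b)
    actual = (np.zeros(shape_a) + np.zeros(shape_b)).shape  # a real element-wise add
    print(shape_a, "+", shape_b, "->", got)
    assert got == expected == actual
```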
Now some examples that cannot be broadcast:
A (1d array): 3
B (1d array): 4 # trailing dimensions do not match

A (2d array): 2 x 1
B (3d array): 8 x 4 x 3 # second-to-last dimensions do not match
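Both failing cases raise a ValueError as soon as the element-wise operation is attempted, as this sketch shows. For the second pair, A is padded to 1 x 2 x 1; the trailing 1 vs 3 is fine, but 2 vs 4 is not.

```python
import numpy as np

failing = [((3,), (4,)), ((2, 1), (8, 4, 3))]

errors = []
for shape_a, shape_b in failing:
    try:
        np.zeros(shape_a) + np.zeros(shape_b)
    except ValueError as err:
        errors.append(str(err))
        print(err)
```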
Now some concrete applications:
>>> x = np.arange(4)
>>> xx = x.reshape(4, 1)
>>> y = np.ones(5)
>>> z = np.ones((3, 4))
>>> x.shape
(4,)
>>> y.shape
(5,)
>>> x+y
ValueError: operands could not be broadcast together with shapes (4,) (5,)
>>> xx.shape
(4, 1)
>>> y.shape
(5,)
>>> (xx+y).shape
(4, 5)
>>> xx + y
array([[ 1., 1., 1., 1., 1.],
       [ 2., 2., 2., 2., 2.],
       [ 3., 3., 3., 3., 3.],
       [ 4., 4., 4., 4., 4.]])
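Adding a column to a row this way computes every pairwise sum x[i] + y[j], so the result matches np.add.outer. The sketch below also uses x[:, np.newaxis], an equivalent way to get the (4, 1) shape, and shows the z array defined above combining with x.

```python
import numpy as np

x = np.arange(4)
y = np.ones(5)
z = np.ones((3, 4))

# x[:, np.newaxis] inserts a length-1 axis: same (4, 1) shape as x.reshape(4, 1)
col = x[:, np.newaxis]
result = col + y                      # (4, 1) with (5,) -> (4, 5)

outer = np.add.outer(x, y)            # every pairwise sum x[i] + y[j]
print(result.shape)                   # (4, 5)
print(np.array_equal(result, outer))  # True

# x (4,) is padded to (1, 4), then repeated over the 3 rows of z
print((x + z).shape)                  # (3, 4)
```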
How does NumPy expand the operands when computing xx + y?
xx (2d array): 4 x 1
y (1d array): 5
Result (2d array): 4 x 5
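That is, conceptually xx is repeated across 5 columns and y across 4 rows, and the two resulting 4 x 5 arrays are added element by element. (NumPy does not actually copy the data; as the last rule says, the stretched axis simply has stride 0.)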
# Repeat xx across 5 columns
# equivalent to np.dot(xx, np.ones((1, 5)))
array([[ 0., 0., 0., 0., 0.],
       [ 1., 1., 1., 1., 1.],
       [ 2., 2., 2., 2., 2.],
       [ 3., 3., 3., 3., 3.]])
# Repeat y across 4 rows
array([[ 1., 1., 1., 1., 1.],
       [ 1., 1., 1., 1., 1.],
       [ 1., 1., 1., 1., 1.],
       [ 1., 1., 1., 1., 1.]])
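The expansion can be reproduced explicitly. In this sketch np.tile makes real copies, while np.broadcast_arrays returns the same expanded operands as zero-stride views; adding the tiled copies gives exactly xx + y.

```python
import numpy as np

xx = np.arange(4).reshape(4, 1)  # shape (4, 1)
y = np.ones(5)                   # shape (5,)

# Explicit copies: repeat xx across 5 columns and y across 4 rows
xx_tiled = np.tile(xx, (1, 5))   # (4, 5)
y_tiled = np.tile(y, (4, 1))     # (4, 5)

# np.broadcast_arrays builds the same expanded operands as views, without copying
xx_view, y_view = np.broadcast_arrays(xx, y)

print(np.array_equal(xx_tiled + y_tiled, xx + y))  # True
print(xx_view.strides[1], y_view.strides[0])       # 0 0: the stretched axes have stride 0
```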
To finish, a figure from the official documentation: