小马的世界

【机器学习】二维高斯函数的理解

2021-11-30 · 7 min read
机器学习 数学

看了好久,终于弄懂了二维高斯函数。谈谈我的简单理解。

高斯函数也就是正态分布的密度函数。一维高斯函数我们在概统中学习过,其方程为

y=ae(xμ)22σ2y=ae^{-\frac{(x-\mu)^2}{2\sigma ^2}}

或者写为

y=aexp((xμ)22σ2)y=a\exp (-\frac{(x-\mu)^2}{2\sigma ^2})

其中μ\mu代表中心(均值),σ\sigma表示分布的幅度(标准差),aa表示高度。
若要使得对xx求积分的值为1,则

a=1(2πσ2)1/2a=\frac{1}{(2\pi \sigma ^2)^{1/2}}

绘制图形的代码如下:

def gauss(mu, sigma, a):
    return a * np.exp(-(x - mu)**2 / (2 * sigma**2))

x = np.linspace(-4, 4, 100)
plt.figure(figsize=(4, 4))
plt.plot(x, gauss(0, 1, 1), 'black', linewidth=3)
plt.plot(x, gauss(2, 2, 0.5), 'gray', linewidth=3)

plt.ylim(-.5, 1.5)
plt.xlim(-4, 4)
plt.grid(True)
plt.show()


黑色线为μ=0,σ=1,a=1\mu=0, \sigma=1, a=1, 灰色线为μ=2,σ=2,a=0.5\mu=2, \sigma=2, a=0.5


到了二维,输入的就不只是xx了,而是x=[x0,x1]T\boldsymbol{x} = [x_0, x_1]^T
而高斯公式就变成了

y=aexp{12(xμ)TΣ1(xμ)}y=a\cdot \exp\{-\frac{1}{2}(\boldsymbol{x}-\boldsymbol{\mu})^T\Sigma^{-1}(\boldsymbol{x}-\boldsymbol{\mu})\}

对比一维

y=aexp(12(xμ)σ2(xμ))y=a\cdot \exp (-\frac{1}{2}(x-\mu)\sigma^{-2}(x-\mu))

依然还是用三个变量用来控制二维高斯函数的形状:aa还是表示高度,μ\boldsymbol{\mu}表示均值向量(中心向量),表示函数分布的中心:

μ=[μ0 μ1]T\boldsymbol{\mu}= [\mu _0 \ \mu _1]^T

Σ\boldsymbol{\Sigma}是协方差矩阵,是一个如下所示的2×22\times 2矩阵:

Σ=[σ02σ01σ01σ12]\boldsymbol{\Sigma}=\begin{bmatrix} \sigma _0^2&\sigma _{01} \\ \sigma _{01} & \sigma _1^2 \end{bmatrix}

Σ\boldsymbol{\Sigma}中,σ02\sigma _0^2σ12\sigma _1^2用来调整x0x_0x1x_1方向分布的幅度(可以理解为两个一维分布中分别的σ\sigma)。$ \sigma _{01} 用于调整函数分布方向上的斜率,如果是正数,那么函数图形是右上左下方向↙️的椭圆;如果是负数,则是左上右下方向↘️的椭圆(x_0为横轴x_1$为纵轴的情况)。

简单起见,我们暂时设μ=[μ0 μ1]T=[0 0]T\boldsymbol{\mu}= [\mu _0 \ \mu _1]^T=[0 \ 0]^T,然后计算(xμ)TΣ1(xμ)(\boldsymbol{x}-\boldsymbol{\mu})^T\Sigma^{-1}(\boldsymbol{x}-\boldsymbol{\mu})。则可以发现,这个式子可以化成由x0x_0x1x_1组成的二次型。

(xμ)TΣ1(xμ)=[x0 x1]1σ02σ12σ012[σ02σ01σ01σ12][x0x1]=1σ02σ12σ012(σ12x022σ01x0x1+σ02x12)(\boldsymbol{x}-\boldsymbol{\mu})^T\Sigma^{-1}(\boldsymbol{x}-\boldsymbol{\mu}) \\ =[x_0\ x_1]\cdot \frac{1}{\sigma _0^2 \sigma _1^2- \sigma _{01}^2}\begin{bmatrix} \sigma _0^2&\sigma _{01} \\ \sigma _{01} & \sigma _1^2 \end{bmatrix}\begin{bmatrix} x_0 \\ x_1 \end{bmatrix}\\ =\frac{1}{\sigma _0^2 \sigma _1^2- \sigma _{01}^2}(\sigma _1^2 x _0^2 - 2\sigma _{01}x_0x_1 + \sigma _0^2x_1^2)

若要使得积分的值为1,则

a=12π1Σ1/2a=\frac{1}{2\pi}\frac{1}{|\boldsymbol{\Sigma}|^{1/2}}

其中Σ=det(Σ)=σ02σ12σ012|\boldsymbol{\Sigma}|=\det (\boldsymbol{\Sigma})=\sigma _0^2 \sigma _1^2-\sigma_{01}^2
二维高斯函数的代码如下:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
%matplotlib inline

# 高斯函数 -----------------------------
def gauss(x, mu, sigma):
    N, D = x.shape
    c1 = 1 / (2 * np.pi)**(D / 2)
    c2 = 1 / (np.linalg.det(sigma)**(1 / 2))
    inv_sigma = np.linalg.inv(sigma)
    c3 = x - mu
    c4 = np.dot(c3, inv_sigma)
    c5 = np.zeros(N)
    for d in range(D):
        c5 = c5 + c4[:, d] * c3[:, d]
    p = c1 * c2 * np.exp(-c5 / 2)
    return p

绘图的代码如下所示:

X_range0=[-3, 3]
X_range1=[-3, 3]

# 显示等高线 --------------------------------
def show_contour_gauss(mu, sig):
    xn = 40  # 等高线的分辨率
    x0 = np.linspace(X_range0[0], X_range0[1], xn)
    x1 = np.linspace(X_range1[0], X_range1[1], xn)
    xx0, xx1 = np.meshgrid(x0, x1)
    x = np.c_[np.reshape(xx0, [xn * xn, 1]), np.reshape(xx1, [xn * xn, 1])]
    f = gauss(x, mu, sig)
    f = f.reshape(xn, xn)
    f = f.T
    cont = plt.contour(xx0, xx1, f, 15, colors='k')
    plt.grid(True)
    
# 三维图形 ----------------------------------
def show3d_gauss(ax, mu, sig):
    xn = 40  # 等高线的分辨率
    x0 = np.linspace(X_range0[0], X_range0[1], xn)
    x1 = np.linspace(X_range1[0], X_range1[1], xn)
    xx0, xx1 = np.meshgrid(x0, x1)
    x = np.c_[np.reshape(xx0, [xn * xn, 1]), np.reshape(xx1, [xn * xn, 1])]
    f = gauss(x, mu, sig)
    f = f.reshape(xn, xn)
    f = f.T
    ax.plot_surface(xx0, xx1, f, 
                    rstride=2, cstride=2, alpha=0.3, 
                    color='blue', edgecolor='black')
    
# 主处理 -----------------------------------
mu = np.array([1, 0.5])            # (A)
sigma = np.array([[2, 1], [1, 1]]) # (B)
Fig = plt.figure(1, figsize=(7, 3))
Fig.add_subplot(1, 2, 1)
show_contour_gauss(mu, sigma)
plt.xlim(X_range0)
plt.ylim(X_range1)
plt.xlabel('$x_0$', fontsize=14)
plt.ylabel('$x_1$', fontsize=14)
Ax = Fig.add_subplot(1, 2, 2, projection='3d')
show3d_gauss(Ax, mu, sigma)
Ax.set_zticks([0.05, 0.10])
Ax.set_xlabel('$x_0$', fontsize=14)
Ax.set_ylabel('$x_1$', fontsize=14)
Ax.view_init(40, -100)
plt.show()