Implementation of EM Algorithm on Gaussian Mixture Model

Huarui Zhou

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from scipy.stats import multivariate_normal as mvn

1D GMM¶

In [3]:
# Draw 10,000 samples from a known 3-component mixture of 1D normals
# (means 0, 4, 8; std devs 1.2, 0.8, 1; weights 0.2, 0.5, 0.3) so the
# EM fit below has a ground truth of K = 3 to recover.
samples1d = np.concatenate([np.random.normal(0, 1.2, 2000), np.random.normal(4,0.8 , 5000), np.random.normal(8, 1, 3000)])
# Plot a histogram of the generated samples
plt.hist(samples1d, bins=30, density=True, color='blue', label='Generated Samples')
plt.title('Histogram of Normal Distribution Samples')
plt.xlabel('Values')
plt.ylabel('Probability Density')
plt.legend()
plt.show()
In [4]:
def GMM_1d(data, K, max_iter, error):
    """Fit a K-component 1D Gaussian mixture to `data` with the EM algorithm.

    Parameters
    ----------
    data : 1D array of samples.
    K : number of mixture components.
    max_iter : maximum number of EM iterations.
    error : convergence tolerance on the max absolute change in the means.

    Returns
    -------
    (weights, means, variances) : three arrays of length K.
    """
    N = len(data)
    w_list = np.ones(K) / K                                   # uniform initial weights
    mean_list = np.random.choice(data, K, replace=False)      # init means at K random samples
    var_list = np.ones(K)                                     # unit initial variances

    for _ in range(max_iter):
        # E-step: responsibilities r[i, j] ∝ w_j * N(x_i | mu_j, var_j),
        # normalized so each row sums to 1.
        r_list = np.zeros((N, K))
        for j in range(K):
            r_list[:, j] = w_list[j] * norm.pdf(data, mean_list[j], np.sqrt(var_list[j]))
        r_list /= np.sum(r_list, axis=1, keepdims=True)

        # M-step: weighted MLE of weights, means, and variances.
        N_k = np.sum(r_list, axis=0)                          # effective count per component
        mean_list_updated = np.sum(r_list * data[:, np.newaxis], axis=0) / N_k
        var_list_updated = np.sum(r_list * (data[:, np.newaxis] - mean_list_updated) ** 2, axis=0) / N_k
        w_list = N_k / N

        converged = np.max(np.abs(mean_list_updated - mean_list)) < error
        # BUG FIX: the original never wrote var_list_updated back into
        # var_list, so every E-step used the initial unit variances and
        # the variance parameters never actually participated in EM.
        mean_list = mean_list_updated
        var_list = var_list_updated
        if converged:
            break
    # Returning the tracked state (instead of *_updated locals) also makes
    # max_iter == 0 safe rather than a NameError.
    return w_list, mean_list, var_list
    
In [5]:
# define a function to calculate the log likelihood
def log_likelihood1d(data, w_list, mean_list, var_list):
    """Log-likelihood of `data` under the 1D Gaussian mixture.

    Parameters
    ----------
    data : 1D array of samples.
    w_list, mean_list, var_list : per-component weights, means, variances.

    Returns
    -------
    Scalar sum over samples of log( sum_j w_j * N(x_i | mu_j, var_j) ).
    """
    data = np.asarray(data)
    # Vectorized: density[i, j] = w_j * N(x_i | mu_j, var_j).  One pdf call
    # per component over the whole sample array replaces the original
    # Python double loop over samples and components.
    density = np.zeros((len(data), len(w_list)))
    for j in range(len(w_list)):
        density[:, j] = w_list[j] * norm.pdf(data, mean_list[j], np.sqrt(var_list[j]))
    return np.sum(np.log(np.sum(density, axis=1)))
In [6]:
#Tune the number (K) of normal distributions in GMM 
# Fit GMMs with K = 1..5 and plot the negative log-likelihood of each fit;
# the elbow of the curve suggests the number of components to use
# (the data were generated with 3 components).
log_lkh = []
for k in [1,2,3,4,5]:
    w_list, mean_list, var_list = GMM_1d(data = samples1d,K=k,max_iter = 1000,error = 1e-5)
    log_lkh.append(-log_likelihood1d(samples1d, w_list, mean_list,var_list))
plt.plot([1,2,3,4,5],log_lkh)
plt.xlabel('K')
plt.ylabel('Negative Log Likelihood')
plt.show()
In [7]:
def plot_gmm(data, w_list, mean_list, var_list):
    """Overlay the fitted 1D GMM density curve on a histogram of the data."""
    plt.hist(data, bins=50, density=True, color='blue', label='Data Histogram')
    # Evaluate the mixture density on a fine grid spanning the data range.
    xs = np.linspace(np.min(data), np.max(data), 1000)
    density = np.zeros_like(xs)
    for weight, mu, var in zip(w_list, mean_list, var_list):
        density += weight * norm.pdf(xs, mu, np.sqrt(var))
    plt.plot(xs, density, color='red', label='GMM fit curve')
    plt.title('1D Gaussian Mixture Model')
    plt.xlabel('Data Values')
    plt.ylabel('Probability Density')
    plt.legend()
    plt.show()
In [9]:
# Fit the final model with K = 3 (the true number of components) and
# overlay the fitted mixture density on the data histogram.
w_list, mean_list, var_list = GMM_1d(data = samples1d,K=3,max_iter = 1000,error = 1e-5)
plot_gmm(samples1d,w_list, mean_list, var_list)

2D GMM¶

In [17]:
# Generate 2D normal samples
# Five bivariate normal clusters (1000 points total) with distinct means
# and covariances; the 2D EM fit below should recover K = 5 components.
samples1 = np.random.multivariate_normal([0, 0], [[0.3, 0], [0, 0.3]], 300)
samples2 = np.random.multivariate_normal([3, 8], [[0.8, 0.5], [0.5, 0.8]], 200)
samples3 = np.random.multivariate_normal([0, 3], [[0.4, 0], [0, 1]], 100)
samples4 = np.random.multivariate_normal([4, 2], [[1, 0], [0, 1]], 200)
samples5 = np.random.multivariate_normal([5, 6], [[0.2, 0], [0, 0.2]], 200)
samples2d = np.concatenate([samples1,samples2,samples3,samples4,samples5])

# Plot the generated samples
plt.scatter(samples2d[:, 0], samples2d[:, 1], alpha=0.5)
plt.title('2D Normal Samples')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.show()
In [18]:
# Define a function to perform GMM
def GMM_2d(data, K, max_iter, error, random_state=42):
    """Fit a K-component 2D Gaussian mixture to `data` with the EM algorithm.

    Parameters
    ----------
    data : (N, 2) array of samples.
    K : number of mixture components.
    max_iter : maximum number of EM iterations.
    error : convergence tolerance on the max absolute change in the means.
    random_state : seed for the choice of initial means.
        BUG FIX: the original accepted this parameter but never used it,
        so initialization was not reproducible; it now seeds a Generator.

    Returns
    -------
    (weights, means, covariances) : length-K weight array, (K, 2) mean
    array, and a list of K (2, 2) covariance matrices.
    """
    N = len(data)
    r_list = np.empty((N, K))
    w_list = np.ones(K) / K
    rng = np.random.default_rng(random_state)
    indices = rng.choice(N, K, replace=False)     # init means at K random samples
    mean_list = data[indices, :]
    cov_list = [np.eye(2) for _ in range(K)]      # identity initial covariances

    for _ in range(max_iter):
        # E-step: responsibilities, normalized per sample.
        for j in range(K):
            r_list[:, j] = w_list[j] * mvn(mean_list[j], cov_list[j]).pdf(data)
        r_list /= np.sum(r_list, axis=1, keepdims=True)

        # M-step: weighted MLE of weights, means, and covariances.
        N_k = np.sum(r_list, axis=0)              # effective count per component
        mean_list_updated = r_list.T @ data / N_k[:, np.newaxis]
        cov_list = [
            (data - mean_list_updated[k]).T
            @ ((data - mean_list_updated[k]) * r_list[:, [k]]) / N_k[k]
            for k in range(K)
        ]
        w_list = N_k / N

        converged = np.max(np.abs(mean_list_updated - mean_list)) < error
        mean_list = mean_list_updated
        if converged:
            break
    # Returning the tracked mean_list also makes max_iter == 0 safe.
    return w_list, mean_list, cov_list
In [19]:
# define a function to calculate the log likelihood
def log_likelihood2d(data, w_list, mean_list, cov_list):
    """Log-likelihood of `data` under the 2D Gaussian mixture.

    Parameters
    ----------
    data : (N, 2) array of samples.
    w_list, mean_list, cov_list : per-component weights, means, covariances.

    Returns
    -------
    Scalar sum over samples of log( sum_j w_j * N(x_i | mu_j, Sigma_j) ).
    """
    density = np.zeros(len(data))
    for j in range(len(w_list)):
        # One pdf call per component over the whole sample array replaces
        # the original Python loop over individual points.
        density += w_list[j] * mvn(mean_list[j], cov_list[j]).pdf(data)
    return np.sum(np.log(density))
In [20]:
#Tune the number (K) of normal distributions in GMM 
# Fit 2D GMMs with K = 1..6 and plot the negative log-likelihood; the
# elbow should appear near the true number of clusters (5).
log_lkh = [ ]
for k in [1,2,3,4,5,6]:
    w_list, mean_list, cov_list = GMM_2d(samples2d,k, 1000,1e-5)
    log_lkh.append(-log_likelihood2d(samples2d, w_list, mean_list,cov_list))
plt.plot([1,2,3,4,5,6],log_lkh) 
plt.xlabel('K')
plt.ylabel('Negative Log Likelihood')
plt.show()
In [21]:
# define a function to calculate log g(x)
def log_density(point, w_list, mean_list, cov_list):
    """Return log of the mixture density g(x) evaluated at `point`
    (works element-wise when `point` is an array of points)."""
    mixture = 0.0
    for weight, mu, cov in zip(w_list, mean_list, cov_list):
        mixture = mixture + weight * mvn(mu, cov).pdf(point)
    return np.log(mixture)

def GMM_2d_plot(data, w_list, mean_list, cov_list):
    """Scatter the data and draw contours of the fitted GMM's negative
    log-density over the data's bounding box."""
    # Build a 100x100 evaluation grid covering the data range.
    gx, gy = np.meshgrid(
        np.linspace(min(data[:, 0]), max(data[:, 0]), 100),
        np.linspace(min(data[:, 1]), max(data[:, 1]), 100),
    )
    grid_pts = np.column_stack([gx.ravel(), gy.ravel()])

    # Negative log-density on the grid, reshaped back to the grid layout.
    zz = (-log_density(grid_pts, w_list, mean_list, cov_list)).reshape(gx.shape)

    # Scatter plot of the samples with density contours on top.
    plt.figure(figsize=(8, 8))
    plt.scatter(data[:, 0], data[:, 1], alpha=0.5, label='Data points')
    plt.contour(gx, gy, zz, levels=10, cmap='viridis', linewidths=2)
    plt.title('Density Estimation for Gaussian Mixture Model')
    plt.xlabel('X-axis')
    plt.ylabel('Y-axis')
    plt.legend()
    plt.show()
In [23]:
# Fit the final 2D model with K = 5 (the true number of clusters) and
# plot the density contours over the scattered data.
w_list, mean_list, cov_list = GMM_2d(samples2d,5, 1000,1e-6)
GMM_2d_plot(samples2d,w_list, mean_list, cov_list)