What are the mean and variance of a 0-censored multivariate normal?
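Let $Z \sim \mathcal N(\mu, \Sigma)$ be a random vector in $\mathbb R^d$. What are the mean vector and covariance matrix of $Z_+ = \max(0, Z)$, with the max computed elementwise?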
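This comes up because, if we use ReLU activation functions in a deep network and assume via the CLT that the inputs to a given layer are approximately normal, then this is the distribution of the outputs.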
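Since the max is taken elementwise, each entry of the mean and covariance of $Z_+ = \max(0, Z)$, $Z \sim \mathcal N(\mu, \Sigma)$, depends on at most two coordinates of $Z$:
$$\mathbb E[Z_+]_i = \mathbb E[(Z_i)_+], \qquad \operatorname{Cov}(Z_+)_{ij} = \operatorname{Cov}\left((Z_i)_+, (Z_j)_+\right).$$
So we only need the mean and variance of a censored univariate normal, and the covariance of the censored versions of two jointly normal variables. We'll use some results from

S. Rosenbaum (1961). Moments of a Truncated Bivariate Normal Distribution. JRSS B, vol. 23, pp. 405-408. (JSTOR)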
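Rosenbaum considers
$$\begin{bmatrix}\tilde X \\ \tilde Y\end{bmatrix} \sim \mathcal N\left(\begin{bmatrix}0 \\ 0\end{bmatrix}, \begin{bmatrix}1 & \rho \\ \rho & 1\end{bmatrix}\right)$$
and the truncation event $V = \{\tilde X \ge a_x, \tilde Y \ge a_y\}$. Specifically, we'll use the following three results, his (1), (3), and (5). First, define the following, where $\phi$ and $\Phi$ are the standard normal density and CDF:
$$\begin{aligned}
q_x &= \phi(a_x), & q_y &= \phi(a_y), \\
Q_x &= \Phi(-a_x), & Q_y &= \Phi(-a_y), \\
R_{xy} &= \Phi\left(\frac{\rho a_x - a_y}{\sqrt{1 - \rho^2}}\right), & R_{yx} &= \Phi\left(\frac{\rho a_y - a_x}{\sqrt{1 - \rho^2}}\right), \\
r_{xy} &= \frac{\sqrt{1 - \rho^2}}{\sqrt{2\pi}}\, \phi\left(\sqrt{\frac{a_x^2 - 2\rho a_x a_y + a_y^2}{1 - \rho^2}}\right).
\end{aligned}$$
Rosenbaum shows that
$$\Pr(V)\,\mathbb E[\tilde X \mid V] = q_x R_{xy} + \rho q_y R_{yx} \tag{1}$$
$$\Pr(V)\,\mathbb E[\tilde X^2 \mid V] = \Pr(V) + a_x q_x R_{xy} + \rho^2 a_y q_y R_{yx} + \rho r_{xy} \tag{3}$$
$$\Pr(V)\,\mathbb E[\tilde X \tilde Y \mid V] = \rho \Pr(V) + \rho a_x q_x R_{xy} + \rho a_y q_y R_{yx} + r_{xy} \tag{5}$$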
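Rosenbaum's (1), (3), and (5) can be checked numerically. The following is a minimal Monte Carlo sketch, assuming NumPy and SciPy are available. The helper name `rosenbaum_rhs` and the values of `a_x`, `a_y`, and `rho` are hypothetical choices for illustration and do not come from the paper.

```python
import numpy as np
from scipy import stats


def rosenbaum_rhs(a_x, a_y, rho):
    """Hypothetical helper: Pr(V) and the right-hand sides of (1), (3), (5)
    for a standard bivariate normal with correlation rho and
    V = {X~ >= a_x, Y~ >= a_y}."""
    s = np.sqrt(1 - rho**2)
    q_x, q_y = stats.norm.pdf(a_x), stats.norm.pdf(a_y)
    R_xy = stats.norm.cdf((rho * a_x - a_y) / s)
    R_yx = stats.norm.cdf((rho * a_y - a_x) / s)
    r_xy = s / np.sqrt(2 * np.pi) * stats.norm.pdf(
        np.sqrt(a_x**2 - 2 * rho * a_x * a_y + a_y**2) / s)
    # Pr(X~ >= a_x, Y~ >= a_y) = Pr(-X~ <= -a_x, -Y~ <= -a_y); the negated
    # pair has the same standard bivariate normal distribution.
    p_v = stats.multivariate_normal.cdf(
        [-a_x, -a_y], mean=[0.0, 0.0], cov=[[1.0, rho], [rho, 1.0]])
    eq1 = q_x * R_xy + rho * q_y * R_yx
    eq3 = p_v + a_x * q_x * R_xy + rho**2 * a_y * q_y * R_yx + rho * r_xy
    eq5 = rho * p_v + rho * a_x * q_x * R_xy + rho * a_y * q_y * R_yx + r_xy
    return np.array([p_v, eq1, eq3, eq5])


a_x, a_y, rho = -0.3, 0.5, 0.6  # arbitrary illustration values
rng = np.random.default_rng(0)
xy = rng.multivariate_normal([0.0, 0.0], [[1.0, rho], [rho, 1.0]],
                             size=1_000_000)
x_t, y_t = xy[:, 0], xy[:, 1]
in_v = (x_t >= a_x) & (y_t >= a_y)

# Monte Carlo estimates of Pr(V), Pr(V) E[X~|V], Pr(V) E[X~^2|V], Pr(V) E[X~Y~|V]
mc = np.array([in_v.mean(), np.mean(x_t * in_v),
               np.mean(x_t**2 * in_v), np.mean(x_t * y_t * in_v)])
exact = rosenbaum_rhs(a_x, a_y, rho)
print(np.abs(mc - exact).max())  # small, on the order of Monte Carlo error
```

The check works with the unconditional quantities $\Pr(V)\,\mathbb E[\cdot \mid V] = \mathbb E[\cdot\,\mathbf 1_V]$, which avoids dividing by an estimated $\Pr(V)$.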
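It's also useful to consider the special case of (1) and (3) with $a_y = -\infty$, i.e. a one-dimensional truncation. Then $q_y = 0$ (and $a_y q_y \to 0$), $R_{xy} = 1$, $r_{xy} = 0$, and $\Pr(V) = Q_x$, so
$$\Pr(\tilde X \ge a_x)\,\mathbb E[\tilde X \mid \tilde X \ge a_x] = q_x \tag{*}$$
$$\Pr(\tilde X \ge a_x)\,\mathbb E[\tilde X^2 \mid \tilde X \ge a_x] = Q_x + a_x q_x \tag{**}$$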
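The one-dimensional identities (*) and (**) are just the partial moments $\int_{a_x}^\infty t^k \phi(t)\,dt$ for $k = 1, 2$, so they can be checked by numerical quadrature. This is a small sketch using SciPy's `integrate.quad`. The truncation points are arbitrary.

```python
import numpy as np
from scipy import integrate, stats

errors = []
for a_x in (-1.5, 0.0, 0.7):  # arbitrary truncation points
    q_x = stats.norm.pdf(a_x)
    Q_x = stats.norm.cdf(-a_x)  # Pr(X~ >= a_x)
    # Pr(X~ >= a_x) E[X~^k | X~ >= a_x] = integral of t^k phi(t) over [a_x, inf)
    m1, _ = integrate.quad(lambda t: t * stats.norm.pdf(t), a_x, np.inf)
    m2, _ = integrate.quad(lambda t: t**2 * stats.norm.pdf(t), a_x, np.inf)
    errors.append(abs(m1 - q_x))                # (*)
    errors.append(abs(m2 - (Q_x + a_x * q_x)))  # (**)

print(max(errors))  # tiny, at the level of quadrature error
```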
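Now take
$$\begin{bmatrix}X \\ Y\end{bmatrix} = \begin{bmatrix}\mu_x \\ \mu_y\end{bmatrix} + \begin{bmatrix}\sigma_x & 0 \\ 0 & \sigma_y\end{bmatrix}\begin{bmatrix}\tilde X \\ \tilde Y\end{bmatrix} \sim \mathcal N\left(\begin{bmatrix}\mu_x \\ \mu_y\end{bmatrix}, \begin{bmatrix}\sigma_x^2 & \rho\sigma_x\sigma_y \\ \rho\sigma_x\sigma_y & \sigma_y^2\end{bmatrix}\right),$$
which is the joint distribution of any two coordinates of the original normal vector, and set
$$a_x = -\frac{\mu_x}{\sigma_x}, \qquad a_y = -\frac{\mu_y}{\sigma_y}.$$
These are the values of $\tilde X$ and $\tilde Y$ when $X = 0$ and $Y = 0$, so $\{X > 0\} = \{\tilde X > a_x\}$ and $Q_x = \Pr(X > 0)$. Now, using (*), we get
$$\begin{aligned}
\mathbb E[X_+] &= \Pr(X > 0)\,\mathbb E[X \mid X > 0] + \Pr(X \le 0) \cdot 0 \\
&= \Pr(\tilde X > a_x)\left(\mu_x + \sigma_x\,\mathbb E[\tilde X \mid \tilde X > a_x]\right) \\
&= Q_x \mu_x + q_x \sigma_x.
\end{aligned}$$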
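As a sketch, $\mathbb E[X_+] = Q_x\mu_x + q_x\sigma_x$ can be written as a scalar function and compared with direct numerical integration of $x$ times the $\mathcal N(\mu_x, \sigma_x^2)$ density over $x > 0$. The function names `relu_normal_mean` and `relu_normal_mean_quad` are hypothetical, and the test parameters are arbitrary.

```python
import numpy as np
from scipy import integrate, stats


def relu_normal_mean(mu_x, sigma_x):
    """E[max(0, X)] for X ~ N(mu_x, sigma_x**2), as Q_x mu_x + q_x sigma_x."""
    z = mu_x / sigma_x        # z = -a_x
    Q_x = stats.norm.cdf(z)   # Phi(-a_x) = Pr(X > 0)
    q_x = stats.norm.pdf(z)   # phi(a_x); phi is symmetric
    return Q_x * mu_x + q_x * sigma_x


def relu_normal_mean_quad(mu_x, sigma_x):
    """Same quantity by numerically integrating x * density over x > 0."""
    val, _ = integrate.quad(
        lambda t: t * stats.norm.pdf(t, loc=mu_x, scale=sigma_x), 0, np.inf)
    return val


cases = [(1.3, 0.8), (-0.5, 2.0), (0.0, 1.0)]
for mu_x, sigma_x in cases:
    print(mu_x, sigma_x, relu_normal_mean(mu_x, sigma_x),
          relu_normal_mean_quad(mu_x, sigma_x))
```

For $\mu_x = 0$ the formula reduces to $\sigma_x / \sqrt{2\pi}$, the mean of a half-normal.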
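and using both (*) and (**) yields
$$\begin{aligned}
\mathbb E[X_+^2] &= \Pr(\tilde X > a_x)\,\mathbb E\left[(\mu_x + \sigma_x \tilde X)^2 \mid \tilde X > a_x\right] \\
&= Q_x \mu_x^2 + 2 q_x \mu_x \sigma_x + (Q_x + a_x q_x)\sigma_x^2 \\
&= Q_x \mu_x^2 + q_x \mu_x \sigma_x + Q_x \sigma_x^2,
\end{aligned}$$
using $a_x \sigma_x = -\mu_x$, so that
$$\begin{aligned}
\operatorname{Var}[X_+] &= \mathbb E[X_+^2] - \mathbb E[X_+]^2 \\
&= Q_x(1 - Q_x)\mu_x^2 + q_x(1 - 2Q_x)\mu_x\sigma_x + (Q_x - q_x^2)\sigma_x^2.
\end{aligned}$$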
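The variance formula can be checked the same way, against first and second moments obtained by quadrature. This is a minimal sketch. The names `relu_normal_var` and `relu_normal_var_quad` are hypothetical, and the parameters are arbitrary.

```python
import numpy as np
from scipy import integrate, stats


def relu_normal_var(mu_x, sigma_x):
    """Var[max(0, X)] for X ~ N(mu_x, sigma_x**2), from the closed form."""
    z = mu_x / sigma_x
    Q_x, q_x = stats.norm.cdf(z), stats.norm.pdf(z)
    return (Q_x * (1 - Q_x) * mu_x**2
            + q_x * (1 - 2 * Q_x) * mu_x * sigma_x
            + (Q_x - q_x**2) * sigma_x**2)


def relu_normal_var_quad(mu_x, sigma_x):
    """Same quantity from numerically integrated E[X_+] and E[X_+^2]."""
    def dens(t):
        return stats.norm.pdf(t, loc=mu_x, scale=sigma_x)
    m1, _ = integrate.quad(lambda t: t * dens(t), 0, np.inf)
    m2, _ = integrate.quad(lambda t: t**2 * dens(t), 0, np.inf)
    return m2 - m1**2


cases = [(1.3, 0.8), (-0.5, 2.0), (0.0, 1.0)]
for mu_x, sigma_x in cases:
    print(mu_x, sigma_x, relu_normal_var(mu_x, sigma_x),
          relu_normal_var_quad(mu_x, sigma_x))
```

With $\mu_x = 0$ only the last term survives, giving $\sigma_x^2\left(\tfrac12 - \tfrac1{2\pi}\right)$.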
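To find $\operatorname{Cov}(X_+, Y_+)$, we'll need $\mathbb E[X_+ Y_+]$. Since $X_+ Y_+ = 0$ outside $V = \{\tilde X > a_x, \tilde Y > a_y\} = \{X > 0, Y > 0\}$,
$$\begin{aligned}
\mathbb E[X_+ Y_+] &= \Pr(V)\,\mathbb E\left[(\mu_x + \sigma_x\tilde X)(\mu_y + \sigma_y\tilde Y) \mid V\right] \\
&= \mu_x\mu_y \Pr(V) + \mu_y\sigma_x \Pr(V)\,\mathbb E[\tilde X \mid V] + \mu_x\sigma_y \Pr(V)\,\mathbb E[\tilde Y \mid V] + \sigma_x\sigma_y \Pr(V)\,\mathbb E[\tilde X\tilde Y \mid V] \\
&= \mu_x\mu_y\Pr(V) + \mu_y\sigma_x(q_x R_{xy} + \rho q_y R_{yx}) + \mu_x\sigma_y(\rho q_x R_{xy} + q_y R_{yx}) \\
&\quad + \sigma_x\sigma_y\left(\rho\Pr(V) + \rho a_x q_x R_{xy} + \rho a_y q_y R_{yx} + r_{xy}\right) \\
&= (\mu_x\mu_y + \rho\sigma_x\sigma_y)\Pr(V) + \mu_y\sigma_x q_x R_{xy} + \mu_x\sigma_y q_y R_{yx} + \sigma_x\sigma_y r_{xy}.
\end{aligned}$$
The third line uses (1), its mirror image with $x$ and $y$ swapped, and (5). The last line uses $\sigma_x a_x = -\mu_x$ and $\sigma_y a_y = -\mu_y$ to cancel the remaining $\rho$ terms. Then
$$\operatorname{Cov}(X_+, Y_+) = \mathbb E[X_+Y_+] - (Q_x\mu_x + q_x\sigma_x)(Q_y\mu_y + q_y\sigma_y).$$
Here $\Pr(V) = \Pr(X > 0, Y > 0)$ is a bivariate normal CDF value, which the code computes with `stats.multivariate_normal.cdf`.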
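For a single pair of coordinates, the covariance formula can be sketched as a scalar function and compared against a Monte Carlo estimate. This assumes NumPy and SciPy. The name `relu_bvn_cov` and the parameter values are hypothetical, chosen only for illustration.

```python
import numpy as np
from scipy import stats


def relu_bvn_cov(mu_x, mu_y, sigma_x, sigma_y, rho):
    """Cov(max(0, X), max(0, Y)) for a bivariate normal, from the closed form."""
    a_x, a_y = -mu_x / sigma_x, -mu_y / sigma_y
    s = np.sqrt(1 - rho**2)
    q_x, q_y = stats.norm.pdf(a_x), stats.norm.pdf(a_y)
    Q_x, Q_y = stats.norm.cdf(-a_x), stats.norm.cdf(-a_y)
    R_xy = stats.norm.cdf((rho * a_x - a_y) / s)
    R_yx = stats.norm.cdf((rho * a_y - a_x) / s)
    r_xy = s / np.sqrt(2 * np.pi) * stats.norm.pdf(
        np.sqrt(a_x**2 - 2 * rho * a_x * a_y + a_y**2) / s)
    c_xy = rho * sigma_x * sigma_y
    # Pr(V) = Pr(X > 0, Y > 0) = Pr(-X < 0, -Y < 0), with -(X, Y) ~ N(-mu, Sigma)
    p_v = stats.multivariate_normal.cdf(
        [0.0, 0.0], mean=[-mu_x, -mu_y],
        cov=[[sigma_x**2, c_xy], [c_xy, sigma_y**2]])
    e_xy = ((mu_x * mu_y + c_xy) * p_v
            + mu_y * sigma_x * q_x * R_xy + mu_x * sigma_y * q_y * R_yx
            + sigma_x * sigma_y * r_xy)
    return e_xy - (Q_x * mu_x + q_x * sigma_x) * (Q_y * mu_y + q_y * sigma_y)


params = dict(mu_x=0.4, mu_y=-0.2, sigma_x=1.5, sigma_y=0.7, rho=-0.35)
c = params["rho"] * params["sigma_x"] * params["sigma_y"]
rng = np.random.default_rng(1)
xy = rng.multivariate_normal(
    [params["mu_x"], params["mu_y"]],
    [[params["sigma_x"]**2, c], [c, params["sigma_y"]**2]], size=1_000_000)
mc_cov = np.cov(np.maximum(xy, 0), rowvar=False)[0, 1]
print(relu_bvn_cov(**params), mc_cov)
```

Two quick consistency checks follow from the formula. With $\rho = 0$ it collapses to zero, since $R_{xy} = Q_y$, $R_{yx} = Q_x$, and $r_{xy} = q_x q_y$. With $\mu_x = \mu_y = 0$ it uses $\Pr(V) = \tfrac14 + \tfrac{\arcsin\rho}{2\pi}$.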
Here is some Python code to compute the moments:
```python
import numpy as np
from scipy import stats


def relu_mvn_mean_cov(mu, Sigma):
    mu = np.asarray(mu, dtype=float)
    Sigma = np.asarray(Sigma, dtype=float)
    d, = mu.shape
    assert Sigma.shape == (d, d)

    # index helpers: v[x] is a column (d, 1), v[y] is a row (1, d)
    x = (slice(None), np.newaxis)
    y = (np.newaxis, slice(None))

    sigma2s = np.diagonal(Sigma)
    sigmas = np.sqrt(sigma2s)
    rhos = Sigma / sigmas[x] / sigmas[y]

    prob = np.empty((d, d))  # prob[i, j] = Pr(X_i > 0, X_j > 0)
    zero = np.zeros(d)
    for i in range(d):
        prob[i, i] = np.nan
        for j in range(i + 1, d):
            # Pr(X > 0) = Pr(-X < 0); X ~ N(mu, S) => -X ~ N(-mu, S)
            s = [i, j]
            prob[i, j] = prob[j, i] = stats.multivariate_normal.cdf(
                zero[s], mean=-mu[s], cov=Sigma[np.ix_(s, s)])

    mu_sigs = mu / sigmas  # = -a

    Q = stats.norm.cdf(mu_sigs)  # Q_i = Phi(-a_i)
    q = stats.norm.pdf(mu_sigs)  # q_i = phi(a_i)
    mean = Q * mu + q * sigmas

    # rho_cs is sqrt(1 - rhos**2); but don't calculate diagonal, because
    # it'll just be zero and we're dividing by it (but not using result)
    # use inf instead of nan; stats.norm.cdf doesn't like nan inputs
    rho_cs = 1 - rhos**2
    np.fill_diagonal(rho_cs, np.inf)
    np.sqrt(rho_cs, out=rho_cs)

    R = stats.norm.cdf((mu_sigs[y] - rhos * mu_sigs[x]) / rho_cs)  # R[i, j] = R_ij

    mu_sigs_sq = mu_sigs ** 2
    r_num = mu_sigs_sq[x] + mu_sigs_sq[y] - 2 * rhos * mu_sigs[x] * mu_sigs[y]
    np.fill_diagonal(r_num, 1)  # don't want slightly negative numerator here
    r = rho_cs / np.sqrt(2 * np.pi) * stats.norm.pdf(np.sqrt(r_num) / rho_cs)

    # off-diagonal: E[X_i+ X_j+] - E[X_i+] E[X_j+]; bit[i, j] = mu_j sigma_i q_i R_ij
    bit = mu[y] * sigmas[x] * q[x] * R
    cov = (
        (mu[x] * mu[y] + Sigma) * prob
        + bit + bit.T
        + sigmas[x] * sigmas[y] * r
        - mean[x] * mean[y])

    # diagonal: Var[X_i+]
    cov[range(d), range(d)] = (
        Q * (1 - Q) * mu**2 + (1 - 2 * Q) * q * mu * sigmas
        + (Q - q**2) * sigma2s)

    return mean, cov
```
And a Monte Carlo check:

```python
np.random.seed(12)
d = 4
mu = np.random.randn(d)
L = np.random.randn(d, d)
Sigma = L @ L.T  # a random positive-definite covariance
dist = stats.multivariate_normal(mu, Sigma)

mn, cov = relu_mvn_mean_cov(mu, Sigma)

samps = np.maximum(dist.rvs(10**7), 0)  # apply the ReLU to each sample
mn_est = samps.mean(axis=0)
cov_est = np.cov(samps, rowvar=False)
print(np.max(np.abs(mn - mn_est)), np.max(np.abs(cov - cov_est)))
```
0.000572145310512 0.00298692620286