JAX卷积的雅可比矩阵计算

1 投票
1 回答
39 浏览
提问于 2025-04-14 15:33

我正在使用JAX来进行卷积操作。

def gaussian_kernel(size: int, std: float):
    """Generates a 2D Gaussian kernel."""
    x, y = jnp.mgrid[-size:size+1, -size:size+1]
    g = jnp.exp(-(x**2 + y**2) / (2 * std**2))
    return g / g.sum()
    
def gaussian_blur(image, kernel_size=5, sigma=1.0):
    """Applies Gaussian blur to a 2D image."""
    kernel = gaussian_kernel(kernel_size, sigma)
    blurred_image = convolve2d(image, kernel, mode='same')
    return blurred_image 

基本上,就是一个普通的模糊效果。

不过,我对卷积的导数在输入像素方面是什么样子并不太明白。

也就是说,改变输入像素y会对输出像素x产生什么影响。

我该如何定义这个呢?我该如何从JAX中提取这个信息?我甚至不知道从哪里开始!

我希望能够用JAX提取输出像素相对于输入像素的梯度。

1 个回答

0

你可以使用 jax.jacobian 来计算雅可比矩阵:

image_out = gaussian_blur(image)
image_jac = jax.jacobian(gaussian_blur)(image)

对于一个形状为 (M, N) 的输入 image,输出的 image_out 也会是形状 (M, N),而 image_jac 的形状则是 (M, N, M, N)

这里的值 image_jac[i, j, k, l] 表示 image_out[i, j]image[k, l] 的偏导数。

撰写回答