如何绘制所创建函数的均方误差?

2024-06-17 09:31:43 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图画出不同θ值的均方误差

这是我的df,叫做提示:

tips = array([ 1.01,  1.66,  3.5 ,  3.31,  3.61,  4.71,  2.  ,  3.12,  1.96,
            3.23,  1.71,  5.  ,  1.57,  3.  ,  3.02,  3.92,  1.67,  3.71,
            3.5 ,  3.35,  4.08,  2.75,  2.23,  7.58,  3.18,  2.34,  2.  ,
            2.  ,  4.3 ,  3.  ,  1.45,  2.5 ,  3.  ,  2.45,  3.27,  3.6 ,
            2.  ,  3.07,  2.31,  5.  ,  2.24,  2.54,  3.06,  1.32,  5.6 ,
            3.  ,  5.  ,  6.  ,  2.05,  3.  ,  2.5 ,  2.6 ,  5.2 ,  1.56,
            4.34,  3.51,  3.  ,  1.5 ,  1.76,  6.73,  3.21,  2.  ,  1.98,
            3.76,  2.64,  3.15,  2.47,  1.  ,  2.01,  2.09,  1.97,  3.  ,
            3.14,  5.  ,  2.2 ,  1.25,  3.08,  4.  ,  3.  ,  2.71,  3.  ,
            3.4 ,  1.83,  5.  ,  2.03,  5.17,  2.  ,  4.  ,  5.85,  3.  ,
            3.  ,  3.5 ,  1.  ,  4.3 ,  3.25,  4.73,  4.  ,  1.5 ,  3.  ,
            1.5 ,  2.5 ,  3.  ,  2.5 ,  3.48,  4.08,  1.64,  4.06,  4.29,
            3.76,  4.  ,  3.  ,  1.  ,  4.  ,  2.55,  4.  ,  3.5 ,  5.07,
            1.5 ,  1.8 ,  2.92,  2.31,  1.68,  2.5 ,  2.  ,  2.52,  4.2 ,
            1.48,  2.  ,  2.  ,  2.18,  1.5 ,  2.83,  1.5 ,  2.  ,  3.25,
            1.25,  2.  ,  2.  ,  2.  ,  2.75,  3.5 ,  6.7 ,  5.  ,  5.  ,
            2.3 ,  1.5 ,  1.36,  1.63,  1.73,  2.  ,  2.5 ,  2.  ,  2.74,
            2.  ,  2.  ,  5.14,  5.  ,  3.75,  2.61,  2.  ,  3.5 ,  2.5 ,
            2.  ,  2.  ,  3.  ,  3.48,  2.24,  4.5 ,  1.61,  2.  , 10.  ,
            3.16,  5.15,  3.18,  4.  ,  3.11,  2.  ,  2.  ,  4.  ,  3.55,
            3.68,  5.65,  3.5 ,  6.5 ,  3.  ,  5.  ,  3.5 ,  2.  ,  3.5 ,
            4.  ,  1.5 ,  4.19,  2.56,  2.02,  4.  ,  1.44,  2.  ,  5.  ,
            2.  ,  2.  ,  4.  ,  2.01,  2.  ,  2.5 ,  4.  ,  3.23,  3.41,
            3.  ,  2.03,  2.23,  2.  ,  5.16,  9.  ,  2.5 ,  6.5 ,  1.1 ,
            3.  ,  1.5 ,  1.44,  3.09,  2.2 ,  3.48,  1.92,  3.  ,  1.58,
            2.5 ,  2.  ,  3.  ,  2.72,  2.88,  2.  ,  3.  ,  3.39,  1.47,
            3.  ,  1.25,  1.  ,  1.17,  4.67,  5.92,  2.  ,  2.  ,  1.75,
            3.  ])

这是我的平方损失函数:

def squared_loss(y_obs, theta):
"""
Calculate the squared loss of the observed data and a summary statistic.

Parameters
------------
y_obs: an observed value
theta : some constant representing a summary statistic

Returns
------------
The squared loss between the observation and the summary statistic.
"""
return (y_obs - theta) ** 2

这是我的均方误差函数:

def mean_squared_error(theta, data):
    
    return sum(squared_loss(data, theta)) / len(data)

这就是问题所在:在下面的单元格中,绘制不同θ值的均方误差。注意,θu值是给定的。确保在绘图上标记轴。记住使用我们前面定义的tips变量

theta_values = np.linspace(0, 6, 100)

plt.plot(mean_squared_error(theta_values, tips))

这给了我以下信息: ValueError:操作数无法与形状(244、)(100、)一起广播。

如果图是正确的,则观察到的最小点应为3。有人知道我能做些什么让我的阴谋出现吗?我在想类似于for循环的东西,但不是很确定

谢谢

编辑:尝试输入θu值,但θu值应为给定值且不可触摸

enter image description here


Tags: andthe函数datareturndefsummarystatistic
1条回答
网友
1楼 · 发布于 2024-06-17 09:31:43

出现错误是因为您试图使用均方误差函数广播两个不同大小的数组。tips df的大小为244,在创建theta values数组时,将其设置为100个0到6之间的等距值,结果大小为100

利用

theta_values = np.linspace(0, 6, 244)

您将创建一个带有244个值的theta_values变量,该变量将正确映射到tips数据帧,并且在计算MSE时不会引起问题

编辑: 为了适应OP更新,假设图是误差平方(SE)与θ之比。要计算的整个代码如下所示;以及输出图。提醒:绘制的是平方误差(即y_true(假定为tips)和y_pred(假定为θ)平方之间的误差)与θ的关系。输出似乎显示3左右的波动较小(如OP所示),但OP需要更多澄清

import numpy as np
import matplotlib.pyplot as plt

tips = np.array([ 1.01,  1.66,  3.5 ,  3.31,  3.61,  4.71,  2.  ,  3.12,  1.96,
               3.23,  1.71,  5.  ,  1.57,  3.  ,  3.02,  3.92,  1.67,  3.71,
               3.5 ,  3.35,  4.08,  2.75,  2.23,  7.58,  3.18,  2.34,  2.  ,
               2.  ,  4.3 ,  3.  ,  1.45,  2.5 ,  3.  ,  2.45,  3.27,  3.6 ,
               2.  ,  3.07,  2.31,  5.  ,  2.24,  2.54,  3.06,  1.32,  5.6 ,
               3.  ,  5.  ,  6.  ,  2.05,  3.  ,  2.5 ,  2.6 ,  5.2 ,  1.56,
               4.34,  3.51,  3.  ,  1.5 ,  1.76,  6.73,  3.21,  2.  ,  1.98,
               3.76,  2.64,  3.15,  2.47,  1.  ,  2.01,  2.09,  1.97,  3.  ,
               3.14,  5.  ,  2.2 ,  1.25,  3.08,  4.  ,  3.  ,  2.71,  3.  ,
               3.4 ,  1.83,  5.  ,  2.03,  5.17,  2.  ,  4.  ,  5.85,  3.  ,
               3.  ,  3.5 ,  1.  ,  4.3 ,  3.25,  4.73,  4.  ,  1.5 ,  3.  ,
               1.5 ,  2.5 ,  3.  ,  2.5 ,  3.48,  4.08,  1.64,  4.06,  4.29,
               3.76,  4.  ,  3.  ,  1.  ,  4.  ,  2.55,  4.  ,  3.5 ,  5.07,
               1.5 ,  1.8 ,  2.92,  2.31,  1.68,  2.5 ,  2.  ,  2.52,  4.2 ,
               1.48,  2.  ,  2.  ,  2.18,  1.5 ,  2.83,  1.5 ,  2.  ,  3.25,
               1.25,  2.  ,  2.  ,  2.  ,  2.75,  3.5 ,  6.7 ,  5.  ,  5.  ,
               2.3 ,  1.5 ,  1.36,  1.63,  1.73,  2.  ,  2.5 ,  2.  ,  2.74,
               2.  ,  2.  ,  5.14,  5.  ,  3.75,  2.61,  2.  ,  3.5 ,  2.5 ,
               2.  ,  2.  ,  3.  ,  3.48,  2.24,  4.5 ,  1.61,  2.  , 10.  ,
               3.16,  5.15,  3.18,  4.  ,  3.11,  2.  ,  2.  ,  4.  ,  3.55,
               3.68,  5.65,  3.5 ,  6.5 ,  3.  ,  5.  ,  3.5 ,  2.  ,  3.5 ,
               4.  ,  1.5 ,  4.19,  2.56,  2.02,  4.  ,  1.44,  2.  ,  5.  ,
               2.  ,  2.  ,  4.  ,  2.01,  2.  ,  2.5 ,  4.  ,  3.23,  3.41,
               3.  ,  2.03,  2.23,  2.  ,  5.16,  9.  ,  2.5 ,  6.5 ,  1.1 ,
               3.  ,  1.5 ,  1.44,  3.09,  2.2 ,  3.48,  1.92,  3.  ,  1.58,
               2.5 ,  2.  ,  3.  ,  2.72,  2.88,  2.  ,  3.  ,  3.39,  1.47,
               3.  ,  1.25,  1.  ,  1.17,  4.67,  5.92,  2.  ,  2.  ,  1.75,
               3.  ])

theta_values = np.linspace(0, 6, 244)


def sqr_err(y_true, y_pred):
    """

    :param y_true: true values of y
    :param y_pred: predicted values of y
    :return: array of lenght original data containing mean squared error for each predictions
    """
    if len(y_true) != len(y_pred):
        raise IndexError("Mismathced array sizes, you inputted arrays with sizes {} and {}".format(len(y_true),
                                                                                                  len(y_pred)))
    else:
        length = len(y_true)

    sqrerror_out = [(y_pred[i]-y_true[i])**2 for i in range(length)]

    return np.array(sqrerror_out)


theta_value = np.linspace(0, 6, 244)

Squared_error = sqr_err(tips, theta_value)

plt.figure()
plt.plot(theta_values, Squared_error)
plt.xlabel('Theta Values')
plt.ylabel('Squared Error')
plt.show()

Squared error plot of theta and tips

相关问题 更多 >