10.解析解方法推导线性回归——不容小觑的线性回归算法

引言

线性回归是许多复杂机器学习模型的基础。作为一种基本的机器学习方法,线性回归提供了清晰的思路和工具,通过理解其推导过程,可以更好地掌握机器学习的基本原理和模型设计。

通过阅读本篇博客,你可以:

1.学会如何用解析解的方法推导线性回归的最优解

2.了解如何判定损失函数是凸函数或非凸

一、解析解的推导

通过上一篇9.深入线性回归推导出MSE——不容小觑的线性回归算法-CSDN博客的讲解,我们已经得到了线性回归的损失函数形式,也明确了目标就是最小化损失函数,那么问题就变成了 \theta 什么时候可以使得损失函数最小。

1.最小二乘形式变化

我们已知损失函数公式为 :

J(\theta ) = \frac{1}{2}\sum_{m}^{i=1}(h_{\theta }x_{i}-y_{i})^{2} 

我们将损失函数变化一个形式,变为以下:

\frac{1}{2}(X\theta - y)^{T}(X\theta -y)

其中 X 为变量集,即所有 x_{i} 样本行程的 m 行 n 列的样本矩阵\theta 即我们要求的最优解,它是一个 m 行 1 列的矩阵。这个公式是如何变化而来的呢?

首先原损失函数中的 h_{\theta } 是线性回归模型中的预测函数,h_{\theta }x_{i} 用来表示预测值 \hat{y_{i}} 。所以我们可以得出以下结论:

\hat{y_{i}} = h_{\theta }x_{i} = x_{i}\theta

\hat{y} = h_{\theta }X = X\theta

得到这个结论之后,我们回归到公式本身:

J(\theta ) = \frac{1}{2}\sum_{m}^{i=1}(h_{\theta }x_{i}-y_{i})^{2} 

\Rightarrow J(\theta ) = \frac{1}{2} \sum_{m}^{i=1}(h_{\theta }x_{i}-y_{i})(h_{\theta }x_{i}-y_{i})

将上述公式代入,又由于矩阵的性质,我们需要将其中一项转置,这里就相当于一个长度为 m 的向量乘以它自己,说白了就是对应位置相乘相加。

所以我们的公式变为:

J(\theta ) = \frac{1}{2}(X\theta - y)^{T}(X\theta - y)

由矩阵运算的基本性质

可继续推出公式:

J(\theta) = \frac{1}{2}((X\theta)^{T} - y^{T} )(X\theta - y)

\Rightarrow J(\theta) = \frac{1}{2} (\theta^{T}X^{T} - y^{T})(X\theta - y)

最终,我们得到:

J(\theta) = \frac{1}{2}(\theta^{T}X^{T}X\theta - \theta^{T}X^{T}y-y^{T}X\theta+y^{T}y)

2.推导出模型的解析解形式

假使我们开着小车,从下图中寻找最优解。为了便于理解,我们假设存在横轴表示 \theta ,存在纵轴表示 loss损失,曲线是 loss function

我们把最小二乘看成是一个函数曲线,最优解一定是驻点中某个极小值(驻点顾名思义就是小车可以停驻的点)。从图中我们可以看出,驻点的特定是梯度全为0(梯度:函数在某点上的切线的斜率)。

所以要求出 \theta 的解析解形式,我们就可以通过把函数的一阶导函数推导出来,再使其的值为0以求出 \theta 。依据以下求导公式:

我们能将公式进行推导:

{J}'(\theta) = \frac{1}{2}\left [ {(\theta^{T}X^{T}X\theta)}' - {(\theta^{T}X^{T}y)}'-{(y^{T}X\theta)}' + {(y^{T}y)}'\right ]

由于 X 和 y 是已知的,\theta 是我们要求的答案。所以和 \theta 没关系的部分在求导时可以忽略不计,继续推导为以下公式:

\Rightarrow {J}'(\theta) = \frac{1}{2}[2X^{T}X\theta-X^{T}y-(y^{T}X)^T]

\Rightarrow {J}'(\theta) = \frac{1}{2}[2X^{T}X\theta-2X^{T}y]

\Rightarrow {J}'(\theta) = X^{T}X\theta - X^{T}y

然后我们设置导函数为0,去进一步解出来驻点对应的 \theta 值为多少:

0 = X^{T}X\theta - X^{T}y

\Rightarrow X^{T}X\theta = X^{T}y

由于矩阵与逆矩阵相乘可以得到单位矩阵,所以我们最终可以求出 \theta 的解析解形式(解析解为方程的解析式,是方程的精确解,能在任意精度下满足方程):

\theta = (X^{T}X)^{-1}X^{T}y

这样,我们有数据集 X ,y 时,就可以将数据代入上面解析解公式,去直接求出对应的 \theta 值了。比如我们可以设想 X 为 m 行 n 列的矩阵,y 为 m 行 1 列的列向量。X^{T} 是 n 行 m 列的,所以 X^{T}X 就是 n 行 n 列的矩阵。又因为矩阵求逆形状不变,再次乘以 X^{T} 后变为 n 行 m 列的矩阵。最后乘以 y,结果  \theta 就是 n 行 1 列的列向量!

二、判断损失函数是否为凸函数

对于求解最优解而言,判断一个损失函数是否为凸函数是极其重要的。如果一个损失函数是凸函数,那么局部最优解即为全局最优解,这是因为在凸函数上没有局部极小值的存在,所有的局部极小值都位于全局最小值处。

如上图所示,左上和右下为非凸函数,左下和右上为凸函数。在非凸函数中,有很多条极值点,我们无法直接得到最优解。对于二次可微的函数,我们可以通过判断黑塞矩阵(hessian matrix)是否为半正定的来进行判断,所以我们要对目标函数在点 x 处的二阶偏导数进行求解。

对于我们的式子来说,就是在导函数的基础上再次对 \theta 进行求偏导,由于 X^{T}y 对 \theta 的导数为0,所以再次求偏导后的答案为 X^{T}X。所谓的正定就是答案的特征值全为正数,而半正定无非就是特征值大于等于0。这里我们对损失函数求二阶导的黑塞矩阵是 X^{T}X,自己和自己做点乘,所以答案一定是半正定的。

在此处我们不用深入去讨论数学推导的证明。在机器学习中,损失函数往往是凸函数,在深度学习中的损失函数往往是非凸函数。并且在实际应用当中,我们并不要求找到全局最优解,只要模型适用。机器学习的特点就是不强调模型 100% 正确,而是有价值的,堪用的

三、代码实战求解线性回归算法模型

经过大片的理论讲解,相信大家已经对线性回归模型的实现有了深刻的认识,接下来我们就要通过代码的形式来实战求解线性回归模型

1.导入需要使用的库

import numpy as np
import matplotlib.pyplot as plt

我们需要使用numpy模块进行矩阵之间的运算,最后用matplotlib模块中的绘图功能绘制 X 与 y 的关系图。

2.定义样本集

# 回归,有监督监督机器学习,X,y
X = 2 * np.random.rand(100,1)
y = 5 * 4 * X + np.random.randn(100,1)

这里的 X 是所有 x_{i} 组成的100行1列的矩阵。y 是真实值,5是偏置(截距),4是 X 的权重,后面的 np.random.randn(100,1) 是100行1列以正态分布形成的误差矩阵

3.实现解析解公式求解模型

# 为了求解W0截距项,我们给X矩阵加上一列全为1的X0
X_b = np.c_[np.ones((100,1)),X]

上述代码是通过 np.c_ 的方式将截距项恒为1的权重拼接到 X 矩阵中。

# 实现解析解的公式来求解θ
θ = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)
print(θ)
"""[[5.21509616]
    [3.77011339]]
"""

我们将得到的 \theta 的解析解公式通过代码去实现,np.linalg.inv() 是numpy模块中用来计算逆矩阵的函数,X_b.T 是变量 X_b 的转置,.dot() 是进行点乘运算(对numpy模块的讲解在专栏前面的文章当中7.科学计算模块Numpy(4)ndarray数组的常用操作(二)_ndarray逐元素相加-CSDN博客)。我们通过以上代码就可以表示公式:

\theta = (X^{T}X)^{-1}X^{T}y

输出 \theta 之后我们可以看到,截距项与权重都是相当接近真实情况,但由于误差的存在,我们不可能得到真实值,只能拟合数据得到最优解。

4.使用模型去预测

X_new = np.array([[0], [2]])
X_new_b = np.c_[np.ones((2, 1)), X_new]
print(X_new_b)
"""[[1. 0.]
    [1. 2.]]
"""

我们定义一个2行1列的矩阵作为要预测样本的自变量。再加上截距恒为1的项,就会形成一个2行2列的矩阵。

y_predict = X_new_b.dot(θ)
print(y_predict)
"""[[ 5.21509616]
    [12.75532293]]
"""

随后,我们计算预测值,也就是 \hat{y} 。

y_predict = X_new_b.dot(θ)
print(y_predict)
"""[[ 5.21509616]
    [12.75532293]]
"""

只要使用刚刚拼接完的矩阵点乘 \theta 就可以得到预测值了。

5.绘图

plt.plot(X_new, y_predict, 'r-')
plt.plot(X, y, 'b.')
plt.axis([0, 2, 0, 15])
plt.show()

这边我们使用到了matplotlib模块中的绘图功能,成功绘制了下方的坐标图。其中,横轴表示了输入 x 的值,纵轴表示了 y 的值,红线代表了整体的函数,蓝色的点则是真实值的分布情况。我们可以发现,红色的直线尽可能地穿过了蓝色的点,这就是我们一直说的线性回归模型。

总结

这篇博客讲述了模型解析解的推导原理以及代码实现。希望可以对大家起到作用,谢谢。


关注我,内容持续更新(后续内容在作者专栏《从零基础到AI算法工程师》)!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/881973.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

win11永久关闭Windows Defend

# Win11 Microsoft Defender 防病毒 彻底关闭 Win11 Microsoft Defender 防病毒关闭 **WinR****——输入 gpedit.msc ,打开本地组策略编辑器——计算机配置——管理模板——Windows组件——Microsoft Defender 防病毒——关闭 Microsoft Defender 防病毒策略——设置…

免费在线压缩pdf 压缩pdf在线免费 推荐简单好用

压缩pdf在线免费?在日常生活和工作学习中,处理PDF文件是常见任务。但有时PDF文件体积较大,给传输、存储和分享带来不便。因此,学习PDF文件压缩技巧十分必要。压缩PDF文件是指通过技术手段减小文件占用的存储空间,同时尽…

kafka 一步步探究消费者组与分区分配策略

本期主要聊聊kafka消费者组与分区 消费者组 & 消费者 每个消费者都需要归属每个消费者组,每个分区只能被消费者组中一个消费者消费 上面这段话还不够直观,我们举个例子来说明。 订单系统 订单消息通过 order_topic 发送,该topic 有 5个分区 结算系…

基于YOLO算法的网球运动实时分析-击球速度测量-击球次数(附源码)

这个项目通过分析视频中的网球运动员来测量他们的速度、击球速度以及击球次数。该项目使用YOLO(You Only Look Once)算法来检测球员和网球,并利用卷积神经网络(CNNs)来提取球场的关键点。此实战项目非常适合提升您的机…

基于 Web 的工业设备监测系统:非功能性需求与标准化数据访问机制的架构设计

目录 案例 【说明】 【问题 1】(6 分) 【问题 2】(14 分) 【问题 3】(5 分) 【答案】 【问题 1】解析 【问题 2】解析 【问题 3】解析 相关推荐 案例 阅读以下关于 Web 系统架构设计的叙述,回答问题 1 至问题 3 。 【说明】 某公司拟开发一款基于 Web 的…

【JavaEE】多线程编程引入——认识Thread类

阿华代码,不是逆风,就是我疯,你们的点赞收藏是我前进最大的动力!!希望本文内容能帮到你! 目录 引入: 一:Thread类 1:Thread类可以直接调用 2:run方法 &a…

springboot每次都需要重设密码?明明在springboot的配置中设置了密码

第一步:查看当前的密码是什么? 打开redis-cli.exe,输入config get requirepass,查看当前的密码是什么? 接着,修改redis的配置文件,找到redis的安装目录,找到相关的conf文件&#x…

FreeRTOS下UART的封装

FreeRTOS下UART的封装_哔哩哔哩_bilibili Git使用的一个BUG: 当出现这个问题是因为git本身的安全证书路径有问题,我们需要重新指定路径 P1:UART程序层次

【2024】前端学习笔记7-颜色-位置-字体设置

学习笔记 1.定义:css2.颜色:color3.字体相关属性:font3.1.字体大小:font-size3.2.字体风格:font - style3.3.字体粗细:font - weight3.4.字体族:font - family 4.位置:text-align 1.…

K8s容器运行时,移除Dockershim后存在哪些疑惑?

K8s容器运行时,移除Dockershim后存在哪些疑惑? 大家好,我是秋意零。 K8s版本截止目前(24/09)已经发布到了1.31.x版本。早在K8s版本从1.24.x起(22/05),默认的容器运行时就不再是Doc…

最新Kali Linux超详细安装教程(附镜像包)

一、镜像下载: 链接:https://pan.baidu.com/s/1BfiyAMW6E1u9fhfyv8oH5Q 提取码:tft5 二、配置虚拟机 这里我们以最新的vm17.5为例。进行配置 1.创建新的虚拟机:选择自定义 2.下一步 3.选择稍后安装操作系统 4.选择Debian版本 因…

02_RabbitMQ消息丢失解决方案及死信队列

一、数据丢失 第一种:生产者弄丢了数据。生产者将数据发送到 RabbitMQ 的时候,可能数据就在半路给搞丢了,因为网络问题,都有可能。 第二种:RabbitMQ 弄丢了数据。MQ还没有持久化自己挂了。 第三种:消费端…

Vue3新组件transition(动画过渡)

transition组件&#xff1a;控制V-if与V-show的显示与隐藏动画 1.基本使用 <template><div><button click"falg !falg">切换</button><transition name"fade" :enter-to-class"etc"><div v-if"falg&quo…

为什么git有些commit记录,只有git reflog可以看到,git log看不到?

文章目录 原因分析1. git log 只能显示 **可达的** 提交2. git reflog 记录所有引用的变更 常见导致 git log 看不到提交的原因1. git reset 操作2. git rebase 操作3. 分支删除4. git commit --amend5. 垃圾回收&#xff08;GC&#xff09;* 如何恢复 git log 看不到的提交&am…

数据库系统基础概述

文章目录 前言一、数据库基础概念 1.数据库系统的组成2.数据模型3.数据库的体系结构二、MySQL数据库 1.了解MySQL2.MySQL的特性3.MySQL的应用场景总结 前言 MySQL数据库是一款完全免费的产品&#xff0c;用户可以直接从网上下载使用&#xff0c;不用花费任何费用。这点对于初学…

多语言长文本 AI 关键字提取 API 数据接口

多语言长文本 AI 关键字提取 API 数据接口 AI / 文本 专有模型极速提取 多语言长文本 / 实时语料库。 1. 产品功能 支持长文本关键词提取&#xff1b;多语言关键词识别&#xff1b;基于 AI 模型&#xff0c;提取精准关键词&#xff1b;全接口支持 HTTPS&#xff08;TLS v1.0 …

CentOS7更换阿里云yum更新源

目前CentOS内置的更新安装源经常报错无法更新&#xff0c;或者速度不够理想&#xff0c;这个时候更换国内的镜像源就是一个不错的选择。 备份内置更新源 mv /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.backup 下载阿里云repo源&#xff08;需要系统…

后台数据管理系统 - 项目架构设计-Vue3+axios+Element-plus(0916)

接口文档: https://apifox.com/apidoc/shared-26c67aee-0233-4d23-aab7-08448fdf95ff/api-93850835 接口根路径&#xff1a; http://big-event-vue-api-t.itheima.net 本项目的技术栈 本项目技术栈基于 ES6、vue3、pinia、vue-router 、vite 、axios 和 element-plus http:/…

LeetCode 每周算法 6(图论、回溯)

LeetCode 每周算法 6&#xff08;图论、回溯&#xff09; 图论算法&#xff1a; class Solution: def dfs(self, grid: List[List[str]], r: int, c: int) -> None: """ 深度优先搜索函数&#xff0c;用于遍历并标记与当前位置(r, c)相连的所有陆地&…

HTML讲解(二)head部分

目录 1. 2.的使用 2.1 charset 2.2 name 2.2.1 describe关键字 2.2.2 keywords关键字 2.2.3 author关键字 2.2.4 http-equiv 小心&#xff01;VS2022不可直接接触&#xff0c;否则&#xff01;没这个必要&#xff0c;方源面色淡然一把抓住&#xff01;顷刻炼化&#x…