开发也能看懂的大模型:线性模型

856 阅读6分钟

前情提要:后端开发自学AI知识,内容比较浅显,偏实战;仅适用于入门了解,解决日常沟通障碍。

线性模型是机器学习中最基础且最重要的一类模型,适用于分类、回归等任务。

其核心思想是通过线性关系建模数据输入与输出之间的关系。


1. 什么是线性模型?

线性模型的假设:输出是输入的线性组合

1.1 数学表达

image.png

1.2 分类与回归的区别

  • 回归问题:直接输出 yyy,如预测房价。
  • 分类问题:通过激活函数(如 sigmoid、softmax)将线性模型的输出映射到分类概率。

2. 线性模型的基本形式

image.png

image.png

3. 线性模型的扩展

image.png

4. 优缺点

优点

  1. 简单高效:模型易于训练,计算复杂度低。
  2. 可解释性强:权重 w 反映了每个特征对输出的影响。
  3. 适用场景广:可以处理稀疏数据(如文本分类中的词频矩阵)。

缺点

  1. 局限于线性关系:无法捕获复杂的非线性关系。
  2. 容易受到异常值影响:如线性回归对异常值敏感。
  3. 特征工程依赖大:需要仔细处理输入特征(如归一化、交叉特征)。

5. 应用场景

线性模型在许多实际应用场景中都非常有用,以下是一些典型的例子:

  1. 房价预测

    • 使用线性回归模型,根据房屋的特征(如面积、卧室数量、位置等)来预测房价。这是经典的线性回归应用场景之一。
  2. 市场营销分析

    • 通过线性模型分析广告支出与销售额之间的关系,帮助企业决定如何优化广告预算以最大化收益。
  3. 金融风险管理

    • 使用线性回归来预测股票价格走势或信用评分模型,评估借款人的违约风险。
  4. 医疗领域

    • 利用线性模型根据患者的生理指标(例如血压、胆固醇水平等)预测患病风险或治疗效果。
  5. 生产和质量控制

    • 在制造业中,分析生产参数与产品质量之间的关系,以优化生产过程和提高产品质量。
  6. 社会科学研究

    • 通过线性回归分析调查数据,以确定社会因素(如教育程度、收入水平)对某些结果(如就业率、幸福感)的影响。
  7. 能源消耗预测

    • 根据历史气象数据和使用情况,预测未来的能源需求,以优化电力供应和减少浪费。

这些应用案例展示了线性模型在各行各业中的广泛用途。线性模型之所以如此受欢迎,部分原因在于其易于实现和解释,特别是在需要快速构建原型和验证假设的情况下。

案例概述

假设我们有一组房屋数据集,其中包含多项特征,例如房屋面积、卧室数量、卫生间数量、所在位置等,以及对应的房价。我们的目标是建立一个线性回归模型来预测新房屋的价格。

步骤

  1. 数据收集

    • 首先,需要收集相关数据。这可能来自房地产数据库或公开数据集,比如Kaggle上的房价数据。
  2. 数据预处理

    • 清洗数据:检查数据是否存在缺失值和异常值,并进行相应处理。例如,删除缺失值过多的记录或用平均值/中位数填补。
    • 特征选择与工程:根据经验和数据分析,选择与房价关系较密切的特征。例如,面积、房龄、位置等。
    • 数据标准化:对特征进行标准化或归一化处理,以提高模型收敛速度和性能。
  3. 拆分数据集

    • 将数据集拆分为训练集和测试集(例如80%用于训练,20%用于测试),以评估模型的泛化能力。
  4. 模型训练

    • 使用训练集数据建立线性回归模型。可以利用Python中的scikit-learn库
  5. 模型评估

    • 使用测试数据评估模型性能。通常使用均方误差(MSE)或决定系数(R²)来衡量模型的效果。
  6. 模型优化

    • 如果模型性能不佳,可以考虑特征选择、增加多项式特征、或者使用正则化等方法进行优化。
  7. 模型部署

    • 一旦模型达到满意的性能,可以将其部署在生产环境中,用于实时房价预测。

案例实现

我们将使用一个常见的开源数据集:如Kaggle上的房价预测数据集。下面的例子使用了来自Kaggle的“House Prices - Advanced Regression Techniques”数据集,可以用来进行类似的线性回归分析。

准备工作

首先,确保你已经下载了数据集,并且安装好了所需的Python库:

pip install numpy pandas scikit-learn matplotlib seaborn

假设数据集文件 train.csv 已经下载并放在当前工作目录下:

下载地址:www.kaggle.com/competition…

导入库

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score

加载数据

# 加载数据集
data = pd.read_csv('train.csv')

# 查看数据前几行
print(data.head())

数据预处理

选择几个简单的特征进行实验,例如,'OverallQual', 'GrLivArea', 'GarageCars', 'TotalBsmtSF'。

# 选择特征和目标值
features = ['OverallQual', 'GrLivArea', 'GarageCars', 'TotalBsmtSF']
X = data[features]
y = data['SalePrice']

# 检查缺失值
print(X.isnull().sum())

# 如果有缺失值需要进行填充,这里假设没有缺失

数据集拆分

# 拆分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

模型训练

# 创建线性回归模型并训练
model = LinearRegression()
model.fit(X_train, y_train)

模型预测与评估

# 进行预测
y_pred = model.predict(X_test)

# 评估模型性能
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f'Mean Squared Error: {mse}')
print(f'R-squared: {r2}')

image.png

可视化结果

# 绘制预测值与实际值对比
plt.scatter(y_test, y_pred)
plt.xlabel('Actual Prices')
plt.ylabel('Predicted Prices')
plt.title('Actual vs Predicted Prices')
plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], color='red')
plt.savefig('out.png', dpi=300, bbox_inches='tight')
plt.show()

image.png

总结

通过这个例子,你可以学习如何使用线性回归模型对房价进行预测。为了提高模型的效果,可以进一步探索更多的特征,或者使用更高级的模型和特征工程技术。

附录:完整代码


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score

# 加载数据集
data = pd.read_csv('train.csv')

# 查看数据前几行
#print(data.head())

# 选择特征和目标值
features = ['OverallQual', 'GrLivArea', 'GarageCars', 'TotalBsmtSF']
X = data[features]
y = data['SalePrice']

# 检查缺失值
#print(X.isnull().sum())

# 拆分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建线性回归模型并训练
model = LinearRegression()
model.fit(X_train, y_train)

# 进行预测
y_pred = model.predict(X_test)

# 评估模型性能
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f'Mean Squared Error: {mse}')
print(f'R-squared: {r2}')

# 绘制预测值与实际值对比
plt.scatter(y_test, y_pred)
plt.xlabel('Actual Prices')
plt.ylabel('Predicted Prices')
plt.title('Actual vs Predicted Prices')
plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], color='red')
#plt.show()
plt.savefig('output.png')  # 保存为文件