输入植物照片,用图像识别判断植物种类,给出浇水施肥建议养活盆栽的小白。

25 阅读7分钟

我将为您创建一个基于深度学习的植物识别与养护建议系统。这个项目将使用PyTorch和预训练的ResNet模型来实现图像分类功能。

项目结构

plant_care_ai/ ├── main.py # 主程序入口 ├── plant_classifier.py # 植物识别模块 ├── care_advisor.py # 养护建议模块 ├── utils.py # 工具函数 ├── requirements.txt # 依赖包列表 ├── README.md # 项目说明 └── knowledge_cards.md # 核心知识点卡片

  1. requirements.txt

torch>=1.9.0 torchvision>=0.10.0 Pillow>=8.0.0 opencv-python>=4.5.0 numpy>=1.21.0 matplotlib>=3.5.0 scikit-learn>=1.0.0 flask>=2.0.0

  1. utils.py - 工具函数模块

import os import json import torch from PIL import Image import numpy as np from torchvision import transforms import cv2

class Config: """配置类""" # 模型参数 MODEL_PATH = "models/plant_classifier.pth" NUM_CLASSES = 10 # 支持的植物种类数量

# 图像预处理参数
IMG_SIZE = 224
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# 植物类别(示例)
PLANT_CLASSES = [
    "绿萝", "吊兰", "仙人掌", "多肉植物", "发财树",
    "君子兰", "虎皮兰", "文竹", "芦荟", "富贵竹"
]

# 数据路径
DATA_DIR = "data"
TRAIN_DIR = os.path.join(DATA_DIR, "train")
TEST_DIR = os.path.join(DATA_DIR, "test")

def get_image_transforms(): """获取图像预处理转换""" return transforms.Compose([ transforms.Resize((Config.IMG_SIZE, Config.IMG_SIZE)), transforms.ToTensor(), transforms.Normalize(mean=Config.MEAN, std=Config.STD) ])

def preprocess_image(image_path): """ 预处理输入图像

Args:
    image_path: 图像文件路径
    
Returns:
    预处理后的张量
"""
try:
    # 加载图像
    image = Image.open(image_path).convert('RGB')
    
    # 应用转换
    transform = get_image_transforms()
    processed_image = transform(image).unsqueeze(0)  # 添加batch维度
    
    return processed_image, image
except Exception as e:
    print(f"图像预处理错误: {e}")
    return None, None

def save_prediction_result(predictions, image_path, output_dir="results"): """ 保存预测结果到JSON文件

Args:
    predictions: 预测结果字典
    image_path: 原始图像路径
    output_dir: 输出目录
"""
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# 构建结果文件名
base_name = os.path.splitext(os.path.basename(image_path))[0]
result_file = os.path.join(output_dir, f"{base_name}_result.json")

# 添加时间戳
from datetime import datetime
predictions['timestamp'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
predictions['image_path'] = image_path

# 保存到JSON文件
with open(result_file, 'w', encoding='utf-8') as f:
    json.dump(predictions, f, ensure_ascii=False, indent=2)

print(f"结果已保存至: {result_file}")

def create_sample_data_structure(): """创建示例数据目录结构""" directories = [ Config.TRAIN_DIR, Config.TEST_DIR, os.path.join(Config.DATA_DIR, "validation") ]

for directory in directories:
    if not os.path.exists(directory):
        os.makedirs(directory)
        print(f"创建目录: {directory}")

# 创建示例标签文件
create_sample_labels()

def create_sample_labels(): """创建示例标签文件""" labels = { "categories": Config.PLANT_CLASSES, "description": "常见室内植物分类标签" }

labels_file = os.path.join(Config.DATA_DIR, "labels.json")
with open(labels_file, 'w', encoding='utf-8') as f:
    json.dump(labels, f, ensure_ascii=False, indent=2)

print(f"标签文件已创建: {labels_file}")

3. plant_classifier.py - 植物识别模块

import torch import torch.nn as nn import torchvision.models as models import torchvision.transforms as transforms from PIL import Image import numpy as np import json from utils import Config, preprocess_image, save_prediction_result

class PlantClassifier: """植物分类器类"""

def __init__(self, model_path=None, num_classes=None):
    """
    初始化植物分类器
    
    Args:
        model_path: 预训练模型路径
        num_classes: 分类类别数
    """
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {self.device}")
    
    self.num_classes = num_classes or Config.NUM_CLASSES
    self.class_names = Config.PLANT_CLASSES
    
    # 加载模型
    self.model = self._load_model(model_path)
    self.model.to(self.device)
    self.model.eval()
    
    # 图像预处理
    self.transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])

def _load_model(self, model_path):
    """
    加载预训练模型
    
    Args:
        model_path: 模型文件路径
        
    Returns:
        加载的模型
    """
    # 使用ResNet18作为基础模型
    model = models.resnet18(pretrained=True)
    
    # 修改最后一层全连接层
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, self.num_classes)
    
    # 如果提供了模型路径,则加载权重
    if model_path and os.path.exists(model_path):
        try:
            checkpoint = torch.load(model_path, map_location=self.device)
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"成功加载模型: {model_path}")
        except Exception as e:
            print(f"加载模型失败: {e}")
            print("使用预训练权重继续...")
    else:
        print("未找到模型文件,使用预训练权重")
    
    return model

def predict_single_image(self, image_path, topk=3):
    """
    对单张图像进行预测
    
    Args:
        image_path: 图像文件路径
        topk: 返回前k个预测结果
        
    Returns:
        预测结果字典
    """
    # 预处理图像
    processed_image, original_image = preprocess_image(image_path)
    if processed_image is None:
        return {"error": "图像预处理失败"}
    
    # 移动到设备
    processed_image = processed_image.to(self.device)
    
    # 预测
    with torch.no_grad():
        outputs = self.model(processed_image)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        top_probs, top_indices = torch.topk(probabilities, topk, dim=1)
    
    # 解析结果
    results = []
    for i in range(topk):
        class_idx = top_indices[0][i].item()
        prob = top_probs[0][i].item()
        
        if class_idx < len(self.class_names):
            class_name = self.class_names[class_idx]
        else:
            class_name = f"未知类别_{class_idx}"
        
        results.append({
            "class": class_name,
            "confidence": float(prob),
            "probability": f"{prob:.4f}"
        })
    
    prediction_result = {
        "predicted_class": results[0]["class"],
        "confidence": results[0]["confidence"],
        "top_predictions": results,
        "all_classes": self.class_names
    }
    
    # 保存结果
    save_prediction_result(prediction_result, image_path)
    
    return prediction_result

def predict_batch_images(self, image_paths, topk=3):
    """
    批量预测多张图像
    
    Args:
        image_paths: 图像路径列表
        topk: 返回前k个预测结果
        
    Returns:
        批量预测结果
    """
    batch_results = []
    
    for image_path in image_paths:
        result = self.predict_single_image(image_path, topk)
        batch_results.append(result)
    
    return batch_results

def train_model(self, train_loader, val_loader, epochs=10, lr=0.001):
    """
    训练模型(简化版本)
    
    Args:
        train_loader: 训练数据加载器
        val_loader: 验证数据加载器
        epochs: 训练轮数
        lr: 学习率
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
    
    train_losses = []
    val_accuracies = []
    
    print("开始训练模型...")
    
    for epoch in range(epochs):
        # 训练阶段
        self.model.train()
        running_loss = 0.0
        
        for images, labels in train_loader:
            images, labels = images.to(self.device), labels.to(self.device)
            
            optimizer.zero_grad()
            outputs = self.model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        # 验证阶段
        val_accuracy = self.evaluate_model(val_loader)
        
        train_losses.append(running_loss / len(train_loader))
        val_accuracies.append(val_accuracy)
        
        print(f'Epoch [{epoch+1}/{epochs}], '
              f'Loss: {running_loss/len(train_loader):.4f}, '
              f'Val Accuracy: {val_accuracy:.4f}')
    
    # 保存训练好的模型
    self.save_model("models/trained_plant_classifier.pth")
    print("模型训练完成并已保存")

def evaluate_model(self, data_loader):
    """评估模型性能"""
    self.model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(self.device), labels.to(self.device)
            outputs = self.model(images)
            _, predicted = torch.max(outputs.data, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    return correct / total

def save_model(self, path):
    """保存模型"""
    if not os.path.exists(os.path.dirname(path)):
        os.makedirs(os.path.dirname(path))
    
    torch.save({
        'model_state_dict': self.model.state_dict(),
        'num_classes': self.num_classes,
        'class_names': self.class_names
    }, path)
    print(f"模型已保存至: {path}")

示例使用

if name == "main": # 创建分类器实例 classifier = PlantClassifier()

# 测试预测功能(如果有测试图像)
test_image = "sample_plants/test_plant.jpg"
if os.path.exists(test_image):
    result = classifier.predict_single_image(test_image)
    print("预测结果:", result)
else:
    print("请将测试图像放在 sample_plants/ 目录下")

4. care_advisor.py - 养护建议模块

import json from datetime import datetime, timedelta import random

class CareAdvisor: """植物养护建议专家系统"""

def __init__(self):
    """初始化养护建议数据库"""
    self.care_database = self._load_care_database()
    self.seasonal_factors = self._load_seasonal_factors()

def _load_care_database(self):
    """加载植物养护数据库"""
    care_db = {
        "绿萝": {
            "watering": {
                "frequency": "3-4天",
                "amount": "适量",
                "method": "见干见湿",
                "seasonal_adjustment": {
                    "spring": 1.0,
                    "summer": 1.5,
                    "autumn": 0.8,
                    "winter": 0.5
                }
            },
            "fertilizing": {
                "frequency": "2周一次",
                "type": "液体复合肥",
                "dosage": "稀释1000倍",
                "seasonal_adjustment": {
                    "spring": 1.2,
                    "summer": 1.0,
                    "autumn": 0.8,
                    "winter": 0.3
                }
            },
            "light": "散射光,避免直射",
            "temperature": "15-25°C",
            "humidity": "60-80%",
            "tips": [
                "喜欢湿润环境,可经常喷雾",
                "生长迅速,定期修剪",
                "对甲醛有一定吸收作用"
            ]
        },
        "吊兰": {
            "watering": {
                "frequency": "5-7天",
                "amount": "适中",
                "method": "表土干燥后浇水",
                "seasonal_adjustment": {
                    "spring": 1.0,
                    "summer": 1.3,
                    "autumn": 0.9,
                    "winter": 0.4
                }
            },
            "fertilizing": {
                "frequency": "3周一次",
                "type": "通用液肥",
                "dosage": "薄肥勤施",
                "seasonal_adjustment": {
                    "spring": 1.0,
                    "summer": 0.8,
                    "autumn": 0.6,
                    "winter": 0.2
                }
            },
            "light": "半阴环境,避免强光",
            "temperature": "10-25°C",
            "humidity": "50-70%",
            "tips": [
                "容易繁殖,可分株种植",
                "能净化空气",
                "适应性强,新手友好"
            ]
        },
        "仙人掌": {
            "watering": {
                "frequency": "10-15天",
                "amount": "少量",
                "method": "干透浇透",
                "seasonal_adjustment": {
                    "spring": 1.0,
                    "summer": 1.2,
                    "autumn": 0.7,
                    "winter": 0.1
                }
            },
            "fertilizing": {
                "frequency": "1个月一次",
                "type": "专用仙人掌肥",
                "dosage": "极稀薄",
                "seasonal_adjustment": {
                    "spring": 1.0,
                    "summer": 0.8,
                    "autumn": 0.5,
                    "winter": 0.0
                }
            },
            "light": "充足阳光",
            "temperature": "15-30°C",
            "humidity": "30-50%",
            "tips": [
                "冬季几乎不需要浇水",
                "需要良好排水",
                "避免频繁移动位置"
            ]
        },
        "多肉植物": {
            "watering": {
                "frequency": "7-10天",
                "amount": "少量多次",
                "method": "叶片略软时浇水",
                "seasonal_adjustment": {
                    "spring": 1.0,
                    "summer": 0.8,
                    "autumn": 1.2,
                    "winter": 0.3
                }
            },
            "fertilizing": {
                "frequency": "2个月一次",
                "type": "缓释肥或多肉专用肥",
                "dosage": "少量",
                "seasonal_adjustment": {
                    "spring": 1.0,
                    "summer": 0.6,
                    "autumn": 1.0,
                    "winter": 0.0
                }
            },
            "light": "明亮光照,夏季遮阴",
            "temperature": "10-28°C",
            "humidity": "40-60%",
            "tips": [
                "春秋季是生长期",
                "夏季注意通风降温",
                "配土要疏松透气"
            ]
        },
        "发财树": {
            "watering": {
                "frequency": "7-10天",
                "amount": "适量",
                "method": "表土2cm干燥后浇水",
                "seasonal_adjustment": {
                    "spring": 1.0,
                    "summer": 1.4,
                    "autumn": 0.8,
                    "winter": 0.4
                }
            },
            "fertilizing": {
                "frequency": "1个月一次",
                "type": "观叶植物肥",
                "dosage": "按说明书减半",
                "seasonal_adjustment": {
                    "spring": 1.2,
                    "summer": 1.0,
                    "autumn": 0.6,
                    "winter": 0.2
                }
            },
            "light": "明亮散射光",
            "temperature": "18-25°C",
            "humidity": "50-60%",
            "tips": [
                "不耐积水,注意排水",
                "可以修剪保持造型",
                "寓意吉祥,适合客厅"
            ]
        }
    }
    return care_db

def _load_seasonal_factors(self):
    """加载季节性因素"""
    current_month = datetime.now().month
    if current_month in [3, 4, 5]:
        season = "spring"
    elif current_month in [6, 7, 8]:
        season = "summer"
    elif current_month in [9, 10, 11]:
        season = "autumn"
    else:
        season = "winter"
    
    return {
        "current_season": season,
        "current_month": current_month,
        "next_watering_days": self._calculate_next_watering(season)
    }

def _calculate_next_watering(self, season):
    """根据季节计算下次浇水天数"""
    base_days = {
        "spring": 5,
        "summer": 3,
        "autumn": 7,
        "winter": 12
    }
    return base_days.get(season, 7)

def get_care_advice(self, plant_name, confidence=None):
    """
    获取植物的养护建议
    
    Args:
        plant_name: 植物名称
        confidence: 识别置信度
        
    Returns:
        详细的养护建议
    """
    if plant_name not in self.care_database:
        return self._get_default_advice(plant_name)
    
    plant_info = self.care_database[plant_name]
    season = self.seasonal_factors["current_season"]
    
    # 根据季节调整建议
    adjusted_advice = self._adjust_for_season(plant_info, season)
    
    # 构建完整建议
    advice = {
        "plant_name": plant_name,
        "identification_confidence": confidence,
        "current_season": season,
        "watering_advice": adjusted_advice["watering"],
        "fertilizing_advice": adjusted_advice["fertilizing"],
        "environmental_needs": {
            "light": plant_info["light"],
            "temperature": plant_info["temperature"],
            "humidity": plant_info["humidity"]
        },
        "care_tips": plant_info["tips"],
        "next_actions": self._generate_next_actions(plant_name, season),
        "warning_signs": self._get_warning_signs(plant_name),
        "generated_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }
    
    return advice

def _adjust_for_season(self, plant_info, season):
    """根据季节调整养护建议"""
    watering = plant_info["watering"].copy()
    fertilizing = plant_info["fertilizing"].copy()
    
    # 调整浇水频率
    base_freq = watering["frequency"]
    adjustment = watering["seasonal_adjustment"][season]
    adjusted_freq = self._adjust_frequency(base_freq, adjustment)
    watering["adjusted_frequency"] = adjusted_freq
    watering["seasonal_note"] = f"当前{season}季节,浇水频率{'增加' if adjustment > 1 else '减少'}"
    
    # 调整施肥频率
    base_freq_fert = fertilizing["frequency"]
    adjustment_fert = fertilizing["seasonal_adjustment"][season]
    adjusted_freq_fert = self._adjust_frequency(base_freq_fert, adjustment_fert)
    fertilizing["adjusted_frequency"] = adjusted_freq_fert
    fertilizing["seasonal_note"] = f"当前{season}季节,施肥频率{'增加' if adjustment_fert > 1 else '减少'}"
    
    return {
        "watering": watering,
        "fertilizing": fertilizing
    }

def _adjust_frequency(self, frequency_str, adjustment):
    """调整频率字符串"""
    try:
        # 解析频率字符串中的数字
        import re
        numbers = re.findall(r'\d+', frequency_str)
        if numbers:
            base_num = int(numbers[0])
            adjusted_num = int(base_num * adjustment)
            
            if "天" in frequency_str:
                return f"{adjusted_num}天"
            elif "周" in frequency_str:
                weeks = adjusted_num // 7
                days = adjusted_num % 7
                if days == 0:
                    return f"{weeks}周"
                e

关注我,有更多实用程序等着你!