#!/usr/bin/env python3
import os
import re
import subprocess
import sys
import time
import yaml
import json
import argparse
import tiktoken
from openai import OpenAI
from pathlib import Path
from functools import lru_cache
import threading
from queue import Queue, Empty
import socket
import select
import httpx
import concurrent.futures
# 配置路径
CONFIG_PATH = Path.home() / ".ag_config.yaml"
def load_config():
"""加载配置文件,增加详细日志记录"""
if not CONFIG_PATH.exists():
print(f"❌ 配置文件不存在: {CONFIG_PATH}")
print("请创建配置文件,示例内容见文档")
sys.exit(1)
try:
with open(CONFIG_PATH, "r") as f:
config = yaml.safe_load(f)
# 设置默认值
config.setdefault("model", {})
config["model"].setdefault("remote", {})
config["model"].setdefault("local", {})
config.setdefault("security", {})
config["security"].setdefault("safe_commands", [])
config["security"].setdefault("protected_paths", [])
config["security"].setdefault("dangerous_patterns", [])
config.setdefault("prompts", {})
config.setdefault("history", {})
config.setdefault("optimization", {})
# 打印加载的配置
print(f"🔧 加载配置文件: {CONFIG_PATH}")
print(f" 模型类型: {config['model'].get('type', '未指定')}")
if config["model"].get("type") == "remote":
print(f" 远程模型: {config['model']['remote'].get('model_name', '未指定')}")
elif config["model"].get("type") == "local":
print(f" 本地模型: {config['model']['local'].get('model_name', '未指定')}")
return config
except Exception as e:
print(f"❌ 加载配置文件失败: {str(e)}")
sys.exit(1)
def get_ai_client(config):
"""根据配置创建AI客户端,增加详细日志"""
model_type = config["model"]["type"]
try:
if model_type == "remote":
print("🛰️ 初始化远程AI客户端...")
print(f" API基础URL: {config['model']['remote'].get('base_url', '未指定')}")
print(f" 模型名称: {config['model']['remote'].get('model_name', '未指定')}")
# 创建 httpx 客户端并设置连接池
pool_size = config["optimization"].get("connection_pool_size", 5)
# 创建正确的超时配置
timeout_config = httpx.Timeout(
connect=config["model"]["remote"].get("connect_timeout", 3.0),
read=config["model"]["remote"].get("read_timeout", 20.0),
write=20.0,
pool=5.0
)
http_client = httpx.Client(
limits=httpx.Limits(
max_connections=pool_size,
max_keepalive_connections=pool_size
),
timeout=timeout_config
)
client = OpenAI(
api_key=config["model"]["remote"]["api_key"],
base_url=config["model"]["remote"]["base_url"],
http_client=http_client,
max_retries=config["model"]["remote"].get("max_retries", 2)
)
print("✅ 远程AI客户端初始化成功")
return client
elif model_type == "local":
print("💻 初始化本地AI客户端...")
print(f" API基础URL: {config['model']['local'].get('api_base', '未指定')}")
print(f" 模型名称: {config['model']['local'].get('model_name', '未指定')}")
# 本地模型使用更长的超时时间
timeout_config = httpx.Timeout(
connect=config["model"]["local"].get("connect_timeout", 5.0),
read=config["model"]["local"].get("read_timeout", 60.0),
write=30.0,
pool=5.0
)
client = OpenAI(
api_key="ollama",
base_url=config["model"]["local"]["api_base"],
timeout=timeout_config
)
print("✅ 本地AI客户端初始化成功")
return client
else:
raise ValueError(f"未知模型类型: {model_type}")
except Exception as e:
print(f"❌ AI客户端初始化失败: {str(e)}")
print("请检查配置文件: ~/.ag_config.yaml")
print("提示: 确保API密钥和URL配置正确")
sys.exit(1)
def get_token_encoder(model_name="gpt-3.5-turbo"):
"""获取token编码器,增加日志"""
print(f"🔢 初始化token编码器: {model_name}")
try:
return tiktoken.encoding_for_model(model_name)
except KeyError:
print(f"⚠️ 警告: 模型 {model_name} 没有找到,使用默认编码器")
return tiktoken.get_encoding("cl100k_base")
def num_tokens_from_messages(messages, encoder):
"""计算消息的token数量,增加日志"""
tokens = sum(len(encoder.encode(msg["content"])) for msg in messages)
print(f"📊 消息token数量: {tokens}")
return tokens
def build_messages(question, config, history, args):
"""构建消息列表,增加详细日志"""
print("🧩 构建消息列表...")
# 确定角色
role = args.role if hasattr(args, "role") else "default"
print(f" 使用角色: {role}")
# 获取提示词模板
prompt_template = config["prompts"].get(role, config["prompts"].get("default", "{question}"))
print(f" 提示词模板: {prompt_template}")
# 应用模板
user_content = prompt_template.format(question=question)
print(f" 用户消息内容: {user_content[:100]}...")
messages = [{"role": "user", "content": user_content}]
# 添加历史记录
max_history_items = config["history"].get("max_items", 5)
max_history_tokens = config["history"].get("max_tokens", 500)
max_actual_tokens = config["optimization"].get("max_history_tokens", 300)
print(f" 历史记录设置: 最大条目={max_history_items}, 最大token={max_history_tokens}")
if history:
# 编码器
encoder = get_token_encoder()
# 添加历史记录
history_messages = []
total_tokens = 0
# 从最新到最旧添加历史记录
for msg in reversed(history):
msg_tokens = len(encoder.encode(msg["content"]))
# 检查是否超过token限制
if total_tokens + msg_tokens > max_actual_tokens:
print(f" 达到token限制({max_actual_tokens}),停止添加历史记录")
break
history_messages.insert(0, msg)
total_tokens += msg_tokens
# 组合消息
messages = history_messages + messages
print(f" 添加了 {len(history_messages)} 条历史记录,总token={total_tokens}")
return messages
def stream_response(client, messages, config, model_name):
"""流式获取响应,增加详细日志"""
print("🌊 开始流式获取响应...")
# 获取模型名称
model_type = config["model"]["type"]
actual_model = config["model"][model_type].get("model_name", model_name)
print(f" 使用模型: {actual_model}")
print(f" 发送消息: {len(messages)} 条")
# 创建响应队列
response_queue = Queue()
min_chunk_size = config["optimization"].get("min_chunk_size", 10)
timeout_fallback = config["optimization"].get("timeout_fallback", 3.0)
# 启动流式响应线程
def fetch_response():
try:
stream = client.chat.completions.create(
model=actual_model,
messages=messages,
stream=True,
max_tokens=2000
)
content_chunks = []
for chunk in stream:
if chunk.choices[0].delta.content is not None:
chunk_content = chunk.choices[0].delta.content
content_chunks.append(chunk_content)
response_queue.put(chunk_content)
response_queue.put(None) # 结束标志
print(f"✅ 流式响应完成,接收 {len(content_chunks)} 个块")
except Exception as e:
print(f"❌ 流式响应错误: {str(e)}")
response_queue.put(e)
threading.Thread(target=fetch_response, daemon=True).start()
# 收集响应
response_content = []
buffer = ""
last_chunk_time = time.time()
while True:
try:
# 设置较短的超时时间以实现实时输出
chunk = response_queue.get(timeout=0.1)
if isinstance(chunk, Exception):
raise chunk
if chunk is None:
# 结束标志
if buffer:
print(buffer, end="", flush=True)
response_content.append(buffer)
break
# 添加到缓冲区和完整响应
buffer += chunk
response_content.append(chunk)
# 如果缓冲区达到最小块大小或遇到换行符,则输出
if len(buffer) >= min_chunk_size or "\n" in buffer:
print(buffer, end="", flush=True)
buffer = ""
last_chunk_time = time.time()
except Empty:
# 检查超时回退
if buffer and (time.time() - last_chunk_time > timeout_fallback):
print(buffer, end="", flush=True)
buffer = ""
last_chunk_time = time.time()
# 检查线程是否结束
if not threading.active_count() > 1:
break
# 确保所有内容都已输出
if buffer:
print(buffer, end="", flush=True)
print("\n")
return "".join(response_content)
def generate_response(question, config, history, args):
"""生成响应,增加详细日志"""
print("\n" + "=" * 50)
print(f"💡 开始处理查询: {question}")
print("=" * 50)
# 获取AI客户端
client = get_ai_client(config)
# 构建消息
messages = build_messages(question, config, history, args)
# 获取模型名称
model_type = config["model"]["type"]
model_name = config["model"][model_type].get("model_name", "gpt-3.5-turbo")
# 流式输出响应
print("\n🤖 AI响应:")
response = stream_response(client, messages, config, model_name)
# 添加到历史记录
max_history_items = config["history"].get("max_items", 5)
# 添加用户消息
history.append({"role": "user", "content": question})
# 添加AI响应
history.append({"role": "assistant", "content": response})
# 限制历史记录长度
if len(history) > max_history_items * 2: # 每次对话2条消息
history = history[-(max_history_items * 2):]
return response, history
def is_command_safe(command, config):
"""检查命令是否安全,增加详细日志"""
print(f"🛡️ 检查命令安全性: {command}")
# 检查命令长度
max_length = config["security"].get("max_command_length", 200)
if len(command) > max_length:
print(f"❌ 命令过长 ({len(command)} > {max_length})")
return False
# 检查命令是否在白名单中
safe_commands = config["security"].get("safe_commands", [])
first_word = command.split()[0] if command else ""
if first_word and first_word not in safe_commands:
print(f"❌ 命令 '{first_word}' 不在安全命令白名单中")
return False
# 检查危险模式
dangerous_patterns = config["security"].get("dangerous_patterns", [])
for pattern in dangerous_patterns:
if re.search(pattern, command):
print(f"❌ 检测到危险模式: {pattern}")
return False
# 检查受保护路径
protected_paths = config["security"].get("protected_paths", [])
for path in protected_paths:
if path in command:
print(f"❌ 命令包含受保护路径: {path}")
return False
print("✅ 命令安全检查通过")
return True
def execute_command(command, config):
"""执行命令,增加详细日志"""
print(f"🚀 执行命令: {command}")
try:
# 设置超时时间
timeout = config["security"].get("max_execution_time", 15)
print(f" 设置执行超时: {timeout}秒")
# 执行命令
start_time = time.time()
result = subprocess.run(
command,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
timeout=timeout
)
elapsed = time.time() - start_time
print(f"✅ 命令执行完成 (耗时: {elapsed:.2f}秒)")
# 返回执行结果
return {
"success": True,
"returncode": result.returncode,
"stdout": result.stdout,
"stderr": result.stderr
}
except subprocess.TimeoutExpired:
print(f"❌ 命令执行超时 ({timeout}秒)")
return {
"success": False,
"error": f"命令执行超时 ({timeout}秒)"
}
except Exception as e:
print(f"❌ 命令执行出错: {str(e)}")
return {
"success": False,
"error": str(e)
}
def extract_command(response):
"""从响应中提取命令,增加详细日志"""
print("🔍 从响应中提取命令...")
# 尝试匹配代码块
code_match = re.search(r"```(?:bash|shell)?\n(.+?)\n```", response, re.DOTALL)
if code_match:
command = code_match.group(1).strip()
print(f" 从代码块中提取命令: {command}")
return command
# 尝试匹配单行命令
command_match = re.search(r"^\s*`(.+?)`\s*$", response, re.MULTILINE)
if command_match:
command = command_match.group(1).strip()
print(f" 从内联代码中提取命令: {command}")
return command
# 尝试匹配直接命令
direct_match = re.search(r"^\s*(.+?)\s*$", response)
if direct_match:
command = direct_match.group(1).strip()
print(f" 从纯文本中提取命令: {command}")
return command
print("⚠️ 未找到有效命令")
return None
def parse_args():
"""解析命令行参数,增加详细日志"""
print("🔍 解析命令行参数...")
parser = argparse.ArgumentParser(description="AI命令行助手")
parser.add_argument("query", nargs="*", help="要查询的问题或命令")
parser.add_argument("-c", "--command", action="store_true", help="生成并执行命令")
parser.add_argument("-r", "--role", choices=["default", "engineer", "strict", "command"],
default="default", help="设置AI角色")
args = parser.parse_args()
# 如果没有提供查询,从标准输入读取
if not args.query:
print(" 从标准输入读取查询...")
args.query = [sys.stdin.read().strip()]
args.query = " ".join(args.query)
print(f" 查询内容: {args.query}")
print(f" 命令模式: {'是' if args.command else '否'}")
print(f" AI角色: {args.role}")
return args
def main():
"""主函数,增加详细日志"""
print("=" * 60)
print("🤖 AI命令行助手 - 启动")
print("=" * 60)
# 加载配置
config = load_config()
# 解析参数
args = parse_args()
question = args.query
# 初始化历史记录
history = []
# 如果是命令模式
if args.command:
print("\n🔧 进入命令生成模式")
args.role = "command" # 强制使用命令角色
# 生成响应
response, history = generate_response(question, config, history, args)
# 提取命令
command = extract_command(response)
if not command:
print("❌ 无法从响应中提取有效命令")
sys.exit(1)
# 检查命令安全性
if not is_command_safe(command, config):
print("❌ 命令被阻止执行,安全策略不允许")
sys.exit(1)
# 执行命令
print("\n⚡ 执行命令:")
print(f"$ {command}")
result = execute_command(command, config)
# 显示结果
if result["success"]:
if result["stdout"]:
print("\n✅ 命令输出:")
print(result["stdout"])
if result["stderr"]:
print("\n⚠️ 命令错误:")
print(result["stderr"])
print(f"\n命令退出码: {result['returncode']}")
else:
print(f"\n❌ 命令执行失败: {result['error']}")
else:
# 普通查询模式
print("\n❓ 进入查询回答模式")
response, history = generate_response(question, config, history, args)
print("\n✅ 回答完成")
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\n🛑 程序被用户中断")
sys.exit(0)
except Exception as e:
print(f"\n❌ 发生未处理错误: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(1)
yaml
model:
# 可选: remote (远程API) 或 local (本地模型)
type: remote
# 远程API配置 (DeepSeek/OpenAI等)
remote:
#api_key: "sk-b9b76d7809f54150ab1a8105eedfae3a" # 替换为您的API密钥
#base_url: "https://api.deepseek.com/v1" # DeepSeek API地址
api_key: "sk-" # 替换为您的API密钥
base_url: "https://api.siliconflow.cn/v1/" # DeepSeek API地址
model_name: "Pro/deepseek-ai/DeepSeek-R1"
connect_timeout: 3.0 # 连接超时时间(秒)
read_timeout: 20.0 # 读取超时时间(秒)
max_retries: 2 # 最大重试次数
# 本地模型配置 (兼容Ollama)
local:
model_path: "~/.ollama/models/manifests/registry.ollama.ai/library/llama3"
api_base: "http://localhost:11434/v1"
model_name: "llama3"
connect_timeout: 5.0 # 本地连接超时
read_timeout: 60.0 # 本地读取超时(本地模型可能较慢)
security:
safe_commands: # 允许执行的命令白名单
- ls
- pwd
- cat
- mkdir
- touch
- echo
- grep
- head
- tail
- df
- du
- ps
- whoami
- date
- find
- cp
- mv
- chmod
- tar
- gzip
- wc
- #
protected_paths: # 受保护的系统目录
- /etc
- /bin
- /sbin
- /usr
- /root
- /var
max_command_length: 200 # 最大命令长度
max_execution_time: 15 # 命令最大执行时间(秒)
dangerous_patterns: # 危险命令模式(正则表达式)
- 'rm\s+-rf'
- 'sudo'
- 'dd'
- 'shutdown'
- 'reboot'
- '^>'
- '>>'
- '\|\s*sh\s*$'
- 'chmod\s+[0-7]{3,4}\s+'
- 'chown\s+root'
- '^mv\s+.*\s+/'
- '^cp\s+.*\s+/'
- ';\s*$'
- '&&\s*$'
- '\|\|'
- '`.*`'
prompts:
default: "你是一个有帮助的AI助手,请回答用户问题:{question}"
engineer: "你是一名资深程序员,请用专业术语回答,解释部分需要使用中文:{question}"
strict: "请用最简洁的语言回答:{question}"
command: "你是一个命令行专家,只返回安全的终端命令:{question}"
history:
max_items: 10 # 最大历史对话条数(每次对话2条)
max_tokens: 1000 # 历史对话最大token数
optimization:
max_history_tokens: 300
min_chunk_size: 10
timeout_fallback: 3.0
connection_pool_size: 5 # HTTP连接池大小
request_retries: 1
使用方式
chmod +x ag_terminal.py
echo "alias ag='python3 ~/ag_terminal.py'" >> ~/.bashrc\nsource ~/.bashrc
ai 当前的一些问题: 1、容易失忆(忘记之前写的代码 做的功能 提的要求) 2、本地调用 api 延迟高于网页版 3、输入输出没有按照规则来 4、随着代码输入越来越多 处理耗时长。
好用的 prompt 1、定义输入输出格式 2、让 ai 自我提出质疑 并且改进 3、ai角色扮演。 4、反问 ai 再反问。
请尝试在下次对话时使用需求快照模板,我会严格遵守锚点锁定区域,并在每次修改前声明变更范围和理由。