展示如何创建自定义工具并集成到 Agent 中。
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"time"
"github.com/astercloud/aster/pkg/agent"
"github.com/astercloud/aster/pkg/provider"
"github.com/astercloud/aster/pkg/sandbox"
"github.com/astercloud/aster/pkg/store"
"github.com/astercloud/aster/pkg/tools"
"github.com/astercloud/aster/pkg/types"
)
// 1. 定义天气工具
func WeatherTool() tools.Tool {
return tools.Tool{
Name: "get_weather",
Description: "获取指定城市的实时天气信息(温度、天气状况、湿度)",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{
"type": "string",
"description": "城市名称,例如:北京、上海、深圳",
},
"unit": map[string]interface{}{
"type": "string",
"enum": []string{"celsius", "fahrenheit"},
"description": "温度单位,默认 celsius",
"default": "celsius",
},
},
"required": []string{"city"},
},
Handler: func(ctx context.Context, tc *tools.ToolContext) (interface{}, error) {
// 解析参数
city, ok := tc.Input["city"].(string)
if !ok {
return nil, fmt.Errorf("缺少必需参数: city")
}
unit := "celsius"
if u, ok := tc.Input["unit"].(string); ok {
unit = u
}
// 调用天气 API(简化示例)
weather, err := fetchWeather(city, unit)
if err != nil {
return nil, fmt.Errorf("获取天气失败: %w", err)
}
return weather, nil
},
}
}
// 模拟天气 API 调用
func fetchWeather(city, unit string) (map[string]interface{}, error) {
// 实际应用中,这里应该调用真实的天气 API
// 例如: OpenWeatherMap, WeatherAPI 等
// 模拟数据
temp := 22.5
if unit == "fahrenheit" {
temp = temp*9/5 + 32
}
return map[string]interface{}{
"city": city,
"temperature": temp,
"condition": "晴朗",
"humidity": 65,
"unit": unit,
"timestamp": time.Now().Format(time.RFC3339),
}, nil
}
// 2. 定义数据库查询工具
func DatabaseTool() tools.Tool {
return tools.Tool{
Name: "query_users",
Description: "查询用户数据库,获取用户信息",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{
"type": "string",
"enum": []string{"get", "list", "count"},
},
"user_id": map[string]interface{}{
"type": "string",
},
"limit": map[string]interface{}{
"type": "integer",
"default": 10,
},
},
"required": []string{"action"},
},
Handler: func(ctx context.Context, tc *tools.ToolContext) (interface{}, error) {
action := tc.Input["action"].(string)
// 模拟数据库查询
switch action {
case "count":
return map[string]interface{}{"count": 42}, nil
case "list":
limit := 10
if l, ok := tc.Input["limit"].(float64); ok {
limit = int(l)
}
return map[string]interface{}{
"users": []string{"user1", "user2", "user3"},
"total": limit,
}, nil
case "get":
userID, _ := tc.Input["user_id"].(string)
return map[string]interface{}{
"id": userID,
"name": "John Doe",
"email": "john@example.com",
}, nil
default:
return nil, fmt.Errorf("未知操作: %s", action)
}
},
}
}
// 3. 定义计算工具
func CalculatorTool() tools.Tool {
return tools.Tool{
Name: "calculator",
Description: "执行数学计算(加减乘除、幂运算)",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"expression": map[string]interface{}{
"type": "string",
"description": "数学表达式,例如: 2+3*4",
},
},
"required": []string{"expression"},
},
Handler: func(ctx context.Context, tc *tools.ToolContext) (interface{}, error) {
expr := tc.Input["expression"].(string)
// 简化示例:实际应使用 expr 解析库
// 这里只做简单演示
result := 0.0
fmt.Sscanf(expr, "%f", &result)
return map[string]interface{}{
"expression": expr,
"result": result,
}, nil
},
}
}
func main() {
ctx := context.Background()
// 创建工具注册表
toolRegistry := tools.NewRegistry()
// 注册自定义工具
toolRegistry.Register(WeatherTool())
toolRegistry.Register(DatabaseTool())
toolRegistry.Register(CalculatorTool())
// 创建依赖
deps := &agent.Dependencies{
ToolRegistry: toolRegistry,
SandboxFactory: sandbox.NewFactory(),
ProviderFactory: provider.NewMultiProviderFactory(),
Store: store.NewMemoryStore(),
TemplateRegistry: agent.NewTemplateRegistry(),
}
// 创建 Agent
ag, err := agent.Create(ctx, &types.AgentConfig{
TemplateID: "assistant",
ModelConfig: &types.ModelConfig{
Provider: "anthropic",
Model: "claude-sonnet-4-5",
APIKey: os.Getenv("ANTHROPIC_API_KEY"),
},
Tools: []string{"get_weather", "query_users", "calculator"},
}, deps)
if err != nil {
log.Fatal(err)
}
defer ag.Close()
// 测试自定义工具
tasks := []string{
"查询北京的天气",
"数据库中有多少用户?",
"计算 (15 + 25) * 2",
}
for i, task := range tasks {
fmt.Printf("\n========== 任务 %d ==========\n", i+1)
fmt.Printf("请求: %s\n\n", task)
result, err := ag.Chat(ctx, task)
if err != nil {
log.Printf("错误: %v", err)
continue
}
fmt.Printf("响应: %s\n", result.Message.Content)
}
}
export ANTHROPIC_API_KEY="sk-ant-xxx"
go run main.go
========== 任务 1 ==========
请求: 查询北京的天气
响应: 北京当前天气:
- 温度:22.5°C
- 天气:晴朗
- 湿度:65%
========== 任务 2 ==========
请求: 数据库中有多少用户?
响应: 数据库中目前有 42 位用户。
========== 任务 3 ==========
请求: 计算 (15 + 25) * 2
响应: (15 + 25) * 2 = 80
Description: "获取指定城市的实时天气信息(温度、天气状况、湿度)"
// ✅ 清晰说明工具功能,包含关键信息
Description: "获取天气"
// ❌ 过于简单,LLM 不知道何时使用
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{
"type": "string",
"description": "城市名称,例如:北京、上海", // 提供示例
},
"unit": map[string]interface{}{
"type": "string",
"enum": []string{"celsius", "fahrenheit"}, // 限制值
"default": "celsius", // 提供默认值
},
},
"required": []string{"city"}, // 标记必需参数
}
Handler: func(ctx context.Context, tc *tools.ToolContext) (interface{}, error) {
// 验证输入
city, ok := tc.Input["city"].(string)
if !ok || city == "" {
return nil, fmt.Errorf("city 参数必需且不能为空")
}
// 检查上下文取消
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
// 执行操作
result, err := fetchData(city)
if err != nil {
return nil, fmt.Errorf("操作失败: %w", err)
}
return result, nil
}
Handler: func(ctx context.Context, tc *tools.ToolContext) (interface{}, error) {
// 设置超时
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
// 执行长时间操作
result, err := longRunningOperation(ctx)
return result, err
}
Handler: func(ctx context.Context, tc *tools.ToolContext) (interface{}, error) {
// 打开资源
conn, err := openConnection()
if err != nil {
return nil, err
}
defer conn.Close() // 确保关闭
// 使用资源
result, err := conn.Query()
return result, err
}
func AsyncTool() tools.Tool {
return tools.Tool{
Name: "async_task",
Description: "启动异步任务并返回任务 ID",
Handler: func(ctx context.Context, tc *tools.ToolContext) (interface{}, error) {
taskID := uuid.New().String()
// 在后台执行
go func() {
// 执行长时间任务
time.Sleep(10 * time.Second)
// 保存结果到数据库或缓存
}()
return map[string]interface{}{
"task_id": taskID,
"status": "started",
}, nil
},
}
}
type StatefulTool struct {
cache map[string]interface{}
mu sync.RWMutex
}
func (t *StatefulTool) Tool() tools.Tool {
return tools.Tool{
Name: "cache_get",
Handler: func(ctx context.Context, tc *tools.ToolContext) (interface{}, error) {
key := tc.Input["key"].(string)
t.mu.RLock()
defer t.mu.RUnlock()
if value, ok := t.cache[key]; ok {
return value, nil
}
return nil, fmt.Errorf("key not found: %s", key)
},
}
}