Tools 示例

自定义工具

创建和使用自定义工具

自定义工具示例

展示如何创建自定义工具并集成到 Agent 中。

完整代码

package main

import (
    "context"
    "encoding/json"
    "fmt"
    "log"
    "net/http"
    "os"
    "time"

    "github.com/astercloud/aster/pkg/agent"
    "github.com/astercloud/aster/pkg/provider"
    "github.com/astercloud/aster/pkg/sandbox"
    "github.com/astercloud/aster/pkg/store"
    "github.com/astercloud/aster/pkg/tools"
    "github.com/astercloud/aster/pkg/types"
)

// 1. 定义天气工具
func WeatherTool() tools.Tool {
    return tools.Tool{
        Name:        "get_weather",
        Description: "获取指定城市的实时天气信息(温度、天气状况、湿度)",
        InputSchema: map[string]interface{}{
            "type": "object",
            "properties": map[string]interface{}{
                "city": map[string]interface{}{
                    "type":        "string",
                    "description": "城市名称,例如:北京、上海、深圳",
                },
                "unit": map[string]interface{}{
                    "type":        "string",
                    "enum":        []string{"celsius", "fahrenheit"},
                    "description": "温度单位,默认 celsius",
                    "default":     "celsius",
                },
            },
            "required": []string{"city"},
        },
        Handler: func(ctx context.Context, tc *tools.ToolContext) (interface{}, error) {
            // 解析参数
            city, ok := tc.Input["city"].(string)
            if !ok {
                return nil, fmt.Errorf("缺少必需参数: city")
            }

            unit := "celsius"
            if u, ok := tc.Input["unit"].(string); ok {
                unit = u
            }

            // 调用天气 API(简化示例)
            weather, err := fetchWeather(city, unit)
            if err != nil {
                return nil, fmt.Errorf("获取天气失败: %w", err)
            }

            return weather, nil
        },
    }
}

// 模拟天气 API 调用
func fetchWeather(city, unit string) (map[string]interface{}, error) {
    // 实际应用中,这里应该调用真实的天气 API
    // 例如: OpenWeatherMap, WeatherAPI 等

    // 模拟数据
    temp := 22.5
    if unit == "fahrenheit" {
        temp = temp*9/5 + 32
    }

    return map[string]interface{}{
        "city":        city,
        "temperature": temp,
        "condition":   "晴朗",
        "humidity":    65,
        "unit":        unit,
        "timestamp":   time.Now().Format(time.RFC3339),
    }, nil
}

// 2. 定义数据库查询工具
func DatabaseTool() tools.Tool {
    return tools.Tool{
        Name:        "query_users",
        Description: "查询用户数据库,获取用户信息",
        InputSchema: map[string]interface{}{
            "type": "object",
            "properties": map[string]interface{}{
                "action": map[string]interface{}{
                    "type": "string",
                    "enum": []string{"get", "list", "count"},
                },
                "user_id": map[string]interface{}{
                    "type": "string",
                },
                "limit": map[string]interface{}{
                    "type":    "integer",
                    "default": 10,
                },
            },
            "required": []string{"action"},
        },
        Handler: func(ctx context.Context, tc *tools.ToolContext) (interface{}, error) {
            action := tc.Input["action"].(string)

            // 模拟数据库查询
            switch action {
            case "count":
                return map[string]interface{}{"count": 42}, nil
            case "list":
                limit := 10
                if l, ok := tc.Input["limit"].(float64); ok {
                    limit = int(l)
                }
                return map[string]interface{}{
                    "users": []string{"user1", "user2", "user3"},
                    "total": limit,
                }, nil
            case "get":
                userID, _ := tc.Input["user_id"].(string)
                return map[string]interface{}{
                    "id":    userID,
                    "name":  "John Doe",
                    "email": "john@example.com",
                }, nil
            default:
                return nil, fmt.Errorf("未知操作: %s", action)
            }
        },
    }
}

// 3. 定义计算工具
func CalculatorTool() tools.Tool {
    return tools.Tool{
        Name:        "calculator",
        Description: "执行数学计算(加减乘除、幂运算)",
        InputSchema: map[string]interface{}{
            "type": "object",
            "properties": map[string]interface{}{
                "expression": map[string]interface{}{
                    "type":        "string",
                    "description": "数学表达式,例如: 2+3*4",
                },
            },
            "required": []string{"expression"},
        },
        Handler: func(ctx context.Context, tc *tools.ToolContext) (interface{}, error) {
            expr := tc.Input["expression"].(string)

            // 简化示例:实际应使用 expr 解析库
            // 这里只做简单演示
            result := 0.0
            fmt.Sscanf(expr, "%f", &result)

            return map[string]interface{}{
                "expression": expr,
                "result":     result,
            }, nil
        },
    }
}

func main() {
    ctx := context.Background()

    // 创建工具注册表
    toolRegistry := tools.NewRegistry()

    // 注册自定义工具
    toolRegistry.Register(WeatherTool())
    toolRegistry.Register(DatabaseTool())
    toolRegistry.Register(CalculatorTool())

    // 创建依赖
    deps := &agent.Dependencies{
        ToolRegistry:     toolRegistry,
        SandboxFactory:   sandbox.NewFactory(),
        ProviderFactory:  provider.NewMultiProviderFactory(),
        Store:            store.NewMemoryStore(),
        TemplateRegistry: agent.NewTemplateRegistry(),
    }

    // 创建 Agent
    ag, err := agent.Create(ctx, &types.AgentConfig{
        TemplateID: "assistant",
        ModelConfig: &types.ModelConfig{
            Provider: "anthropic",
            Model:    "claude-sonnet-4-5",
            APIKey:   os.Getenv("ANTHROPIC_API_KEY"),
        },
        Tools: []string{"get_weather", "query_users", "calculator"},
    }, deps)
    if err != nil {
        log.Fatal(err)
    }
    defer ag.Close()

    // 测试自定义工具
    tasks := []string{
        "查询北京的天气",
        "数据库中有多少用户?",
        "计算 (15 + 25) * 2",
    }

    for i, task := range tasks {
        fmt.Printf("\n========== 任务 %d ==========\n", i+1)
        fmt.Printf("请求: %s\n\n", task)

        result, err := ag.Chat(ctx, task)
        if err != nil {
            log.Printf("错误: %v", err)
            continue
        }

        fmt.Printf("响应: %s\n", result.Message.Content)
    }
}

运行示例

export ANTHROPIC_API_KEY="sk-ant-xxx"
go run main.go

输出示例

========== 任务 1 ==========
请求: 查询北京的天气

响应: 北京当前天气:
- 温度:22.5°C
- 天气:晴朗
- 湿度:65%

========== 任务 2 ==========
请求: 数据库中有多少用户?

响应: 数据库中目前有 42 位用户。

========== 任务 3 ==========
请求: 计算 (15 + 25) * 2

响应: (15 + 25) * 2 = 80

工具开发最佳实践

1. 清晰的描述

Description: "获取指定城市的实时天气信息(温度、天气状况、湿度)"
// ✅ 清晰说明工具功能,包含关键信息

Description: "获取天气"
// ❌ 过于简单,LLM 不知道何时使用

2. 完整的 Schema

InputSchema: map[string]interface{}{
    "type": "object",
    "properties": map[string]interface{}{
        "city": map[string]interface{}{
            "type":        "string",
            "description": "城市名称,例如:北京、上海", // 提供示例
        },
        "unit": map[string]interface{}{
            "type":    "string",
            "enum":    []string{"celsius", "fahrenheit"}, // 限制值
            "default": "celsius",                         // 提供默认值
        },
    },
    "required": []string{"city"}, // 标记必需参数
}

3. 错误处理

Handler: func(ctx context.Context, tc *tools.ToolContext) (interface{}, error) {
    // 验证输入
    city, ok := tc.Input["city"].(string)
    if !ok || city == "" {
        return nil, fmt.Errorf("city 参数必需且不能为空")
    }

    // 检查上下文取消
    select {
    case <-ctx.Done():
        return nil, ctx.Err()
    default:
    }

    // 执行操作
    result, err := fetchData(city)
    if err != nil {
        return nil, fmt.Errorf("操作失败: %w", err)
    }

    return result, nil
}

4. 超时控制

Handler: func(ctx context.Context, tc *tools.ToolContext) (interface{}, error) {
    // 设置超时
    ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
    defer cancel()

    // 执行长时间操作
    result, err := longRunningOperation(ctx)
    return result, err
}

5. 资源清理

Handler: func(ctx context.Context, tc *tools.ToolContext) (interface{}, error) {
    // 打开资源
    conn, err := openConnection()
    if err != nil {
        return nil, err
    }
    defer conn.Close() // 确保关闭

    // 使用资源
    result, err := conn.Query()
    return result, err
}

高级示例

异步工具

func AsyncTool() tools.Tool {
    return tools.Tool{
        Name:        "async_task",
        Description: "启动异步任务并返回任务 ID",
        Handler: func(ctx context.Context, tc *tools.ToolContext) (interface{}, error) {
            taskID := uuid.New().String()

            // 在后台执行
            go func() {
                // 执行长时间任务
                time.Sleep(10 * time.Second)
                // 保存结果到数据库或缓存
            }()

            return map[string]interface{}{
                "task_id": taskID,
                "status":  "started",
            }, nil
        },
    }
}

带状态的工具

type StatefulTool struct {
    cache map[string]interface{}
    mu    sync.RWMutex
}

func (t *StatefulTool) Tool() tools.Tool {
    return tools.Tool{
        Name: "cache_get",
        Handler: func(ctx context.Context, tc *tools.ToolContext) (interface{}, error) {
            key := tc.Input["key"].(string)

            t.mu.RLock()
            defer t.mu.RUnlock()

            if value, ok := t.cache[key]; ok {
                return value, nil
            }
            return nil, fmt.Errorf("key not found: %s", key)
        },
    }
}

相关资源