package tools import ( "context" "fmt" "math" "strconv" "strings" "unicode" "github.com/yourname/cyrene-ai/tool-engine/internal/model" ) // CalculatorTool performs safe mathematical expression evaluation. type CalculatorTool struct{} // NewCalculatorTool creates a calculator tool. func NewCalculatorTool() *CalculatorTool { return &CalculatorTool{} } // Definition returns the tool definition for LLM function calling. func (t *CalculatorTool) Definition() model.ToolDefinition { return model.ToolDefinition{ Name: "calculator", Description: "执行数学计算。用于精确计算数学表达式,支持四则运算、三角函数、对数、幂运算等。适用于LLM不擅长的复杂计算场景。", Parameters: map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "expression": map[string]interface{}{ "type": "string", "description": "数学表达式,如 \"2 + 3 * 4\"、\"sqrt(16) + sin(pi/2)\"。支持运算符: + - * / % ^。支持函数: sqrt, sin, cos, tan, abs, floor, ceil, round, log, ln, pow。支持常量: pi, e。", }, }, "required": []string{"expression"}, }, } } // Execute evaluates a mathematical expression. func (t *CalculatorTool) Execute(ctx context.Context, arguments map[string]interface{}) (*model.ToolResult, error) { expression, ok := arguments["expression"].(string) if !ok || strings.TrimSpace(expression) == "" { return &model.ToolResult{ ID: "", Error: "缺少 expression 参数", }, nil } result, err := evaluate(expression) if err != nil { return &model.ToolResult{ ID: "", Error: fmt.Sprintf("计算错误: %v", err), }, nil } return &model.ToolResult{ ID: "", Output: fmt.Sprintf("表达式: %s\n结果: %s", expression, formatResult(result)), }, nil } func formatResult(v float64) string { if v == math.Trunc(v) && math.Abs(v) < 1e15 { return strconv.FormatInt(int64(v), 10) } return strconv.FormatFloat(v, 'g', -1, 64) } type tokenKind int const ( tokNumber tokenKind = iota tokIdent tokOp tokLParen tokRParen tokComma tokEOF ) type token struct { kind tokenKind value string } type lexer struct { input []rune pos int } func newLexer(s string) *lexer { return &lexer{input: []rune(s), pos: 0} } func (l *lexer) next() token { l.skipWhitespace() if l.pos >= len(l.input) { return token{kind: tokEOF} } ch := l.input[l.pos] if unicode.IsDigit(ch) || ch == '.' { start := l.pos hasDot := ch == '.' l.pos++ for l.pos < len(l.input) && (unicode.IsDigit(l.input[l.pos]) || l.input[l.pos] == '.') { if l.input[l.pos] == '.' { if hasDot { break } hasDot = true } l.pos++ } return token{kind: tokNumber, value: string(l.input[start:l.pos])} } if unicode.IsLetter(ch) || ch == '_' { start := l.pos l.pos++ for l.pos < len(l.input) && (unicode.IsLetter(l.input[l.pos]) || unicode.IsDigit(l.input[l.pos]) || l.input[l.pos] == '_') { l.pos++ } return token{kind: tokIdent, value: string(l.input[start:l.pos])} } switch ch { case '+', '-', '*', '/', '%', '^': l.pos++ return token{kind: tokOp, value: string(ch)} case '(': l.pos++ return token{kind: tokLParen} case ')': l.pos++ return token{kind: tokRParen} case ',': l.pos++ return token{kind: tokComma} } return token{kind: tokEOF} } func (l *lexer) skipWhitespace() { for l.pos < len(l.input) && unicode.IsSpace(l.input[l.pos]) { l.pos++ } } type parser struct { lex *lexer cur token peek token } func newParser(lex *lexer) *parser { p := &parser{lex: lex} p.cur = lex.next() p.peek = lex.next() return p } func (p *parser) advance() { p.cur = p.peek p.peek = p.lex.next() } func evaluate(expr string) (float64, error) { lex := newLexer(expr) par := newParser(lex) result, err := par.parseExpression() if err != nil { return 0, err } if par.cur.kind != tokEOF { return 0, fmt.Errorf("表达式末尾存在意外字符") } return result, nil } func (p *parser) parseExpression() (float64, error) { left, err := p.parseTerm() if err != nil { return 0, err } for p.cur.kind == tokOp && (p.cur.value == "+" || p.cur.value == "-") { op := p.cur.value p.advance() right, err := p.parseTerm() if err != nil { return 0, err } if op == "+" { left += right } else { left -= right } } return left, nil } func (p *parser) parseTerm() (float64, error) { left, err := p.parseUnary() if err != nil { return 0, err } for p.cur.kind == tokOp && (p.cur.value == "*" || p.cur.value == "/" || p.cur.value == "%" || p.cur.value == "^") { op := p.cur.value p.advance() right, err := p.parseUnary() if err != nil { return 0, err } switch op { case "*": left *= right case "/": if right == 0 { return 0, fmt.Errorf("除数不能为零") } left /= right case "%": left = math.Mod(left, right) case "^": left = math.Pow(left, right) } } return left, nil } func (p *parser) parseUnary() (float64, error) { if p.cur.kind == tokOp && p.cur.value == "-" { p.advance() val, err := p.parseUnary() if err != nil { return 0, err } return -val, nil } if p.cur.kind == tokOp && p.cur.value == "+" { p.advance() return p.parseUnary() } return p.parseAtom() } func (p *parser) parseAtom() (float64, error) { switch p.cur.kind { case tokNumber: val, err := strconv.ParseFloat(p.cur.value, 64) if err != nil { return 0, fmt.Errorf("无效数字: %s", p.cur.value) } p.advance() return val, nil case tokIdent: name := strings.ToLower(p.cur.value) p.advance() switch name { case "pi": return math.Pi, nil case "e": return math.E, nil } if p.cur.kind != tokLParen { return 0, fmt.Errorf("未知标识符: %s (如果是函数需要加括号)", name) } p.advance() arg, err := p.parseExpression() if err != nil { return 0, err } if p.cur.kind != tokRParen { return 0, fmt.Errorf("函数 %s 缺少右括号", name) } p.advance() return applyFunc(name, arg) case tokLParen: p.advance() val, err := p.parseExpression() if err != nil { return 0, err } if p.cur.kind != tokRParen { return 0, fmt.Errorf("缺少右括号") } p.advance() return val, nil default: return 0, fmt.Errorf("意外的 token: %v", p.cur.value) } } func applyFunc(name string, arg float64) (float64, error) { switch name { case "sqrt": if arg < 0 { return 0, fmt.Errorf("sqrt 参数不能为负数") } return math.Sqrt(arg), nil case "sin": return math.Sin(arg), nil case "cos": return math.Cos(arg), nil case "tan": return math.Tan(arg), nil case "abs": return math.Abs(arg), nil case "floor": return math.Floor(arg), nil case "ceil": return math.Ceil(arg), nil case "round": return math.Round(arg), nil case "log": if arg <= 0 { return 0, fmt.Errorf("log 参数必须大于0") } return math.Log10(arg), nil case "ln": if arg <= 0 { return 0, fmt.Errorf("ln 参数必须大于0") } return math.Log(arg), nil case "pow": return 0, fmt.Errorf("pow 需要两个参数,请使用 ^ 运算符代替") default: return 0, fmt.Errorf("未知函数: %s", name) } }