mirror of
https://github.com/chaitin/MonkeyCode.git
synced 2026-02-08 01:33:31 +08:00
120 lines
3.4 KiB
Go
120 lines
3.4 KiB
Go
|
|
package service
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bytes"
|
||
|
|
"context"
|
||
|
|
"encoding/json"
|
||
|
|
"fmt"
|
||
|
|
"io"
|
||
|
|
"net/http"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/chaitin/MonkeyCode/backend/config"
|
||
|
|
"github.com/chaitin/MonkeyCode/backend/domain"
|
||
|
|
)
|
||
|
|
|
||
|
|
// EmbeddingService 定义向量嵌入服务接口
|
||
|
|
type EmbeddingService interface {
|
||
|
|
// GenerateEmbedding 为代码片段生成向量嵌入
|
||
|
|
GenerateEmbedding(ctx context.Context, snippet *domain.CodeSnippet) ([]float32, error)
|
||
|
|
|
||
|
|
// GenerateEmbeddingFromContent 为代码内容生成向量嵌入
|
||
|
|
GenerateEmbeddingFromContent(ctx context.Context, content string) ([]float32, error)
|
||
|
|
}
|
||
|
|
|
||
|
|
// OpenAIEmbeddingService 实现向量嵌入服务
|
||
|
|
type OpenAIEmbeddingService struct {
|
||
|
|
config *config.Config
|
||
|
|
}
|
||
|
|
|
||
|
|
// NewOpenAIEmbeddingService 创建新的向量嵌入服务实例
|
||
|
|
func NewOpenAIEmbeddingService(cfg *config.Config) EmbeddingService {
|
||
|
|
return &OpenAIEmbeddingService{
|
||
|
|
config: cfg,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// GenerateEmbedding 为代码片段生成向量嵌入
|
||
|
|
func (s *OpenAIEmbeddingService) GenerateEmbedding(ctx context.Context, snippet *domain.CodeSnippet) ([]float32, error) {
|
||
|
|
// 只使用content生成embedding
|
||
|
|
content := snippet.Content
|
||
|
|
return s.GenerateEmbeddingFromContent(ctx, content)
|
||
|
|
}
|
||
|
|
|
||
|
|
// GenerateEmbeddingFromContent 为代码内容生成向量嵌入
|
||
|
|
func (s *OpenAIEmbeddingService) GenerateEmbeddingFromContent(ctx context.Context, content string) ([]float32, error) {
|
||
|
|
// 检查配置是否完整
|
||
|
|
if s.config.Embedding.APIEndpoint == "" {
|
||
|
|
return nil, fmt.Errorf("embedding API endpoint not configured")
|
||
|
|
}
|
||
|
|
if s.config.Embedding.APIKey == "" {
|
||
|
|
return nil, fmt.Errorf("embedding API key not configured")
|
||
|
|
}
|
||
|
|
if s.config.Embedding.ModelName == "" {
|
||
|
|
return nil, fmt.Errorf("embedding model name not configured")
|
||
|
|
}
|
||
|
|
|
||
|
|
// 构造请求体
|
||
|
|
requestBody := map[string]interface{}{
|
||
|
|
"input": content,
|
||
|
|
"model": s.config.Embedding.ModelName,
|
||
|
|
}
|
||
|
|
|
||
|
|
// 将请求体序列化为JSON
|
||
|
|
jsonBody, err := json.Marshal(requestBody)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 创建HTTP请求
|
||
|
|
req, err := http.NewRequestWithContext(ctx, "POST", s.config.Embedding.APIEndpoint, bytes.NewBuffer(jsonBody))
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 设置请求头
|
||
|
|
req.Header.Set("Content-Type", "application/json")
|
||
|
|
req.Header.Set("Authorization", "Bearer "+s.config.Embedding.APIKey)
|
||
|
|
|
||
|
|
// 创建带超时的HTTP客户端
|
||
|
|
client := &http.Client{
|
||
|
|
Timeout: 30 * time.Second, // 30秒超时
|
||
|
|
}
|
||
|
|
|
||
|
|
resp, err := client.Do(req)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to send HTTP request: %w", err)
|
||
|
|
}
|
||
|
|
defer resp.Body.Close()
|
||
|
|
|
||
|
|
// 读取响应体
|
||
|
|
body, err := io.ReadAll(resp.Body)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 检查响应状态码
|
||
|
|
if resp.StatusCode != http.StatusOK {
|
||
|
|
return nil, fmt.Errorf("API request failed with status code %d: %s", resp.StatusCode, string(body))
|
||
|
|
}
|
||
|
|
|
||
|
|
// 解析响应
|
||
|
|
var embeddingResponse struct {
|
||
|
|
Data []struct {
|
||
|
|
Embedding []float32 `json:"embedding"`
|
||
|
|
} `json:"data"`
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := json.Unmarshal(body, &embeddingResponse); err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 检查是否有嵌入向量返回
|
||
|
|
if len(embeddingResponse.Data) == 0 || len(embeddingResponse.Data[0].Embedding) == 0 {
|
||
|
|
return nil, fmt.Errorf("no embedding data returned from API")
|
||
|
|
}
|
||
|
|
|
||
|
|
// 直接返回float32数组
|
||
|
|
return embeddingResponse.Data[0].Embedding, nil
|
||
|
|
}
|