mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-11 17:47:20 +10:00
556 lines
14 KiB
Go
556 lines
14 KiB
Go
package ocm
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"mime"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/sagernet/sing-box/adapter"
|
|
boxService "github.com/sagernet/sing-box/adapter/service"
|
|
"github.com/sagernet/sing-box/common/dialer"
|
|
"github.com/sagernet/sing-box/common/listener"
|
|
"github.com/sagernet/sing-box/common/tls"
|
|
C "github.com/sagernet/sing-box/constant"
|
|
"github.com/sagernet/sing-box/log"
|
|
"github.com/sagernet/sing-box/option"
|
|
"github.com/sagernet/sing/common"
|
|
"github.com/sagernet/sing/common/buf"
|
|
E "github.com/sagernet/sing/common/exceptions"
|
|
M "github.com/sagernet/sing/common/metadata"
|
|
N "github.com/sagernet/sing/common/network"
|
|
aTLS "github.com/sagernet/sing/common/tls"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/openai/openai-go/v3"
|
|
"github.com/openai/openai-go/v3/responses"
|
|
"golang.org/x/net/http2"
|
|
)
|
|
|
|
func RegisterService(registry *boxService.Registry) {
|
|
boxService.Register[option.OCMServiceOptions](registry, C.TypeOCM, NewService)
|
|
}
|
|
|
|
// JSON error envelope written by writeJSONError, serialized as
// {"error":{"type":...,"code":...,"message":...}}.
type (
	// errorResponse wraps the error object under a top-level "error" key.
	errorResponse struct {
		Error errorDetails `json:"error"`
	}

	// errorDetails is the inner error object. The field order below fixes the
	// JSON key order; Code is omitted from the output when empty.
	errorDetails struct {
		Type    string `json:"type"`
		Code    string `json:"code,omitempty"`
		Message string `json:"message"`
	}
)
|
|
|
|
func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(statusCode)
|
|
|
|
json.NewEncoder(w).Encode(errorResponse{
|
|
Error: errorDetails{
|
|
Type: errorType,
|
|
Message: message,
|
|
},
|
|
})
|
|
}
|
|
|
|
// isHopByHopHeader reports whether header (matched case-insensitively) must
// not be copied between the client and the upstream connection.
//
// The set mirrors the hop-by-hop headers of RFC 9110 section 7.6.1 and the
// list used by net/http/httputil.ReverseProxy, plus "Host", which must come
// from the upstream URL rather than from the client. The real header name is
// "Trailer" (RFC 2616 errata 4522); "trailers" is still matched for
// compatibility with the previous version of this list.
func isHopByHopHeader(header string) bool {
	switch strings.ToLower(header) {
	case "connection", "proxy-connection", "keep-alive",
		"proxy-authenticate", "proxy-authorization",
		"te", "trailer", "trailers", "transfer-encoding", "upgrade",
		"host":
		return true
	default:
		return false
	}
}
|
|
|
|
type Service struct {
|
|
boxService.Adapter
|
|
ctx context.Context
|
|
logger log.ContextLogger
|
|
credentialPath string
|
|
credentials *oauthCredentials
|
|
users []option.OCMUser
|
|
httpClient *http.Client
|
|
httpHeaders http.Header
|
|
listener *listener.Listener
|
|
tlsConfig tls.ServerConfig
|
|
httpServer *http.Server
|
|
userManager *UserManager
|
|
accessMutex sync.RWMutex
|
|
usageTracker *AggregatedUsage
|
|
trackingGroup sync.WaitGroup
|
|
shuttingDown bool
|
|
}
|
|
|
|
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) {
|
|
serviceDialer, err := dialer.NewWithOptions(dialer.Options{
|
|
Context: ctx,
|
|
Options: option.DialerOptions{
|
|
Detour: options.Detour,
|
|
},
|
|
RemoteIsDomain: true,
|
|
})
|
|
if err != nil {
|
|
return nil, E.Cause(err, "create dialer")
|
|
}
|
|
|
|
httpClient := &http.Client{
|
|
Transport: &http.Transport{
|
|
ForceAttemptHTTP2: true,
|
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
|
},
|
|
},
|
|
}
|
|
|
|
userManager := &UserManager{
|
|
tokenMap: make(map[string]string),
|
|
}
|
|
|
|
var usageTracker *AggregatedUsage
|
|
if options.UsagesPath != "" {
|
|
usageTracker = &AggregatedUsage{
|
|
LastUpdated: time.Now(),
|
|
Combinations: make([]CostCombination, 0),
|
|
filePath: options.UsagesPath,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
service := &Service{
|
|
Adapter: boxService.NewAdapter(C.TypeOCM, tag),
|
|
ctx: ctx,
|
|
logger: logger,
|
|
credentialPath: options.CredentialPath,
|
|
users: options.Users,
|
|
httpClient: httpClient,
|
|
httpHeaders: options.Headers.Build(),
|
|
listener: listener.New(listener.Options{
|
|
Context: ctx,
|
|
Logger: logger,
|
|
Network: []string{N.NetworkTCP},
|
|
Listen: options.ListenOptions,
|
|
}),
|
|
userManager: userManager,
|
|
usageTracker: usageTracker,
|
|
}
|
|
|
|
if options.TLS != nil {
|
|
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
service.tlsConfig = tlsConfig
|
|
}
|
|
|
|
return service, nil
|
|
}
|
|
|
|
func (s *Service) Start(stage adapter.StartStage) error {
|
|
if stage != adapter.StartStateStart {
|
|
return nil
|
|
}
|
|
|
|
s.userManager.UpdateUsers(s.users)
|
|
|
|
credentials, err := platformReadCredentials(s.credentialPath)
|
|
if err != nil {
|
|
return E.Cause(err, "read credentials")
|
|
}
|
|
s.credentials = credentials
|
|
|
|
if s.usageTracker != nil {
|
|
err = s.usageTracker.Load()
|
|
if err != nil {
|
|
s.logger.Warn("load usage statistics: ", err)
|
|
}
|
|
}
|
|
|
|
router := chi.NewRouter()
|
|
router.Mount("/", s)
|
|
|
|
s.httpServer = &http.Server{Handler: router}
|
|
|
|
if s.tlsConfig != nil {
|
|
err = s.tlsConfig.Start()
|
|
if err != nil {
|
|
return E.Cause(err, "create TLS config")
|
|
}
|
|
}
|
|
|
|
tcpListener, err := s.listener.ListenTCP()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if s.tlsConfig != nil {
|
|
if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) {
|
|
s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...))
|
|
}
|
|
tcpListener = aTLS.NewListener(tcpListener, s.tlsConfig)
|
|
}
|
|
|
|
go func() {
|
|
serveErr := s.httpServer.Serve(tcpListener)
|
|
if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
|
|
s.logger.Error("serve error: ", serveErr)
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Service) getAccessToken() (string, error) {
|
|
s.accessMutex.RLock()
|
|
if !s.credentials.needsRefresh() {
|
|
token := s.credentials.getAccessToken()
|
|
s.accessMutex.RUnlock()
|
|
return token, nil
|
|
}
|
|
s.accessMutex.RUnlock()
|
|
|
|
s.accessMutex.Lock()
|
|
defer s.accessMutex.Unlock()
|
|
|
|
if !s.credentials.needsRefresh() {
|
|
return s.credentials.getAccessToken(), nil
|
|
}
|
|
|
|
newCredentials, err := refreshToken(s.httpClient, s.credentials)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
s.credentials = newCredentials
|
|
|
|
err = platformWriteCredentials(newCredentials, s.credentialPath)
|
|
if err != nil {
|
|
s.logger.Warn("persist refreshed token: ", err)
|
|
}
|
|
|
|
return newCredentials.getAccessToken(), nil
|
|
}
|
|
|
|
func (s *Service) getAccountID() string {
|
|
s.accessMutex.RLock()
|
|
defer s.accessMutex.RUnlock()
|
|
return s.credentials.getAccountID()
|
|
}
|
|
|
|
func (s *Service) isAPIKeyMode() bool {
|
|
s.accessMutex.RLock()
|
|
defer s.accessMutex.RUnlock()
|
|
return s.credentials.isAPIKeyMode()
|
|
}
|
|
|
|
func (s *Service) getBaseURL() string {
|
|
if s.isAPIKeyMode() {
|
|
return openaiAPIBaseURL
|
|
}
|
|
return chatGPTBackendURL
|
|
}
|
|
|
|
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
path := r.URL.Path
|
|
if !strings.HasPrefix(path, "/v1/") {
|
|
writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
|
|
return
|
|
}
|
|
|
|
var proxyPath string
|
|
if s.isAPIKeyMode() {
|
|
proxyPath = path
|
|
} else {
|
|
if path == "/v1/chat/completions" {
|
|
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
|
"chat completions endpoint is only available in API key mode")
|
|
return
|
|
}
|
|
proxyPath = strings.TrimPrefix(path, "/v1")
|
|
}
|
|
|
|
var username string
|
|
if len(s.users) > 0 {
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
|
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
|
return
|
|
}
|
|
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
|
if clientToken == authHeader {
|
|
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
|
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
|
return
|
|
}
|
|
var ok bool
|
|
username, ok = s.userManager.Authenticate(clientToken)
|
|
if !ok {
|
|
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
|
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
|
return
|
|
}
|
|
}
|
|
|
|
var requestModel string
|
|
|
|
if s.usageTracker != nil && r.Body != nil {
|
|
bodyBytes, err := io.ReadAll(r.Body)
|
|
if err == nil {
|
|
var request struct {
|
|
Model string `json:"model"`
|
|
}
|
|
err := json.Unmarshal(bodyBytes, &request)
|
|
if err == nil {
|
|
requestModel = request.Model
|
|
}
|
|
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
}
|
|
}
|
|
|
|
accessToken, err := s.getAccessToken()
|
|
if err != nil {
|
|
s.logger.Error("get access token: ", err)
|
|
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed")
|
|
return
|
|
}
|
|
|
|
proxyURL := s.getBaseURL() + proxyPath
|
|
if r.URL.RawQuery != "" {
|
|
proxyURL += "?" + r.URL.RawQuery
|
|
}
|
|
proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body)
|
|
if err != nil {
|
|
s.logger.Error("create proxy request: ", err)
|
|
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
|
return
|
|
}
|
|
|
|
for key, values := range r.Header {
|
|
if !isHopByHopHeader(key) && key != "Authorization" {
|
|
proxyRequest.Header[key] = values
|
|
}
|
|
}
|
|
|
|
for key, values := range s.httpHeaders {
|
|
proxyRequest.Header.Del(key)
|
|
proxyRequest.Header[key] = values
|
|
}
|
|
|
|
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
|
|
|
if accountID := s.getAccountID(); accountID != "" {
|
|
proxyRequest.Header.Set("ChatGPT-Account-Id", accountID)
|
|
}
|
|
|
|
response, err := s.httpClient.Do(proxyRequest)
|
|
if err != nil {
|
|
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
|
|
return
|
|
}
|
|
defer response.Body.Close()
|
|
|
|
for key, values := range response.Header {
|
|
if !isHopByHopHeader(key) {
|
|
w.Header()[key] = values
|
|
}
|
|
}
|
|
w.WriteHeader(response.StatusCode)
|
|
|
|
trackUsage := s.usageTracker != nil && response.StatusCode == http.StatusOK &&
|
|
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
|
|
if trackUsage {
|
|
s.handleResponseWithTracking(w, response, path, requestModel, username)
|
|
} else {
|
|
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
|
if err == nil && mediaType != "text/event-stream" {
|
|
_, _ = io.Copy(w, response.Body)
|
|
return
|
|
}
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
s.logger.Error("streaming not supported")
|
|
return
|
|
}
|
|
buffer := make([]byte, buf.BufferSize)
|
|
for {
|
|
n, err := response.Body.Read(buffer)
|
|
if n > 0 {
|
|
_, writeError := w.Write(buffer[:n])
|
|
if writeError != nil {
|
|
s.logger.Error("write streaming response: ", writeError)
|
|
return
|
|
}
|
|
flusher.Flush()
|
|
}
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, path string, requestModel string, username string) {
|
|
isChatCompletions := path == "/v1/chat/completions"
|
|
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
|
isStreaming := err == nil && mediaType == "text/event-stream"
|
|
|
|
if !isStreaming {
|
|
bodyBytes, err := io.ReadAll(response.Body)
|
|
if err != nil {
|
|
s.logger.Error("read response body: ", err)
|
|
return
|
|
}
|
|
|
|
var responseModel string
|
|
var inputTokens, outputTokens, cachedTokens int64
|
|
|
|
if isChatCompletions {
|
|
var chatCompletion openai.ChatCompletion
|
|
if json.Unmarshal(bodyBytes, &chatCompletion) == nil {
|
|
responseModel = chatCompletion.Model
|
|
inputTokens = chatCompletion.Usage.PromptTokens
|
|
outputTokens = chatCompletion.Usage.CompletionTokens
|
|
cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens
|
|
}
|
|
} else {
|
|
var responsesResponse responses.Response
|
|
if json.Unmarshal(bodyBytes, &responsesResponse) == nil {
|
|
responseModel = string(responsesResponse.Model)
|
|
inputTokens = responsesResponse.Usage.InputTokens
|
|
outputTokens = responsesResponse.Usage.OutputTokens
|
|
cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
|
|
}
|
|
}
|
|
|
|
if inputTokens > 0 || outputTokens > 0 {
|
|
if responseModel == "" {
|
|
responseModel = requestModel
|
|
}
|
|
if responseModel != "" {
|
|
s.usageTracker.AddUsage(responseModel, inputTokens, outputTokens, cachedTokens, username)
|
|
}
|
|
}
|
|
|
|
_, _ = writer.Write(bodyBytes)
|
|
return
|
|
}
|
|
|
|
flusher, ok := writer.(http.Flusher)
|
|
if !ok {
|
|
s.logger.Error("streaming not supported")
|
|
return
|
|
}
|
|
|
|
var inputTokens, outputTokens, cachedTokens int64
|
|
var responseModel string
|
|
buffer := make([]byte, buf.BufferSize)
|
|
var leftover []byte
|
|
|
|
for {
|
|
n, err := response.Body.Read(buffer)
|
|
if n > 0 {
|
|
data := append(leftover, buffer[:n]...)
|
|
lines := bytes.Split(data, []byte("\n"))
|
|
|
|
if err == nil {
|
|
leftover = lines[len(lines)-1]
|
|
lines = lines[:len(lines)-1]
|
|
} else {
|
|
leftover = nil
|
|
}
|
|
|
|
for _, line := range lines {
|
|
line = bytes.TrimSpace(line)
|
|
if len(line) == 0 {
|
|
continue
|
|
}
|
|
|
|
if bytes.HasPrefix(line, []byte("data: ")) {
|
|
eventData := bytes.TrimPrefix(line, []byte("data: "))
|
|
if bytes.Equal(eventData, []byte("[DONE]")) {
|
|
continue
|
|
}
|
|
|
|
if isChatCompletions {
|
|
var chatChunk openai.ChatCompletionChunk
|
|
if json.Unmarshal(eventData, &chatChunk) == nil {
|
|
if chatChunk.Model != "" {
|
|
responseModel = chatChunk.Model
|
|
}
|
|
if chatChunk.Usage.PromptTokens > 0 {
|
|
inputTokens = chatChunk.Usage.PromptTokens
|
|
cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens
|
|
}
|
|
if chatChunk.Usage.CompletionTokens > 0 {
|
|
outputTokens = chatChunk.Usage.CompletionTokens
|
|
}
|
|
}
|
|
} else {
|
|
var streamEvent responses.ResponseStreamEventUnion
|
|
if json.Unmarshal(eventData, &streamEvent) == nil {
|
|
if streamEvent.Type == "response.completed" {
|
|
completedEvent := streamEvent.AsResponseCompleted()
|
|
if string(completedEvent.Response.Model) != "" {
|
|
responseModel = string(completedEvent.Response.Model)
|
|
}
|
|
if completedEvent.Response.Usage.InputTokens > 0 {
|
|
inputTokens = completedEvent.Response.Usage.InputTokens
|
|
cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens
|
|
}
|
|
if completedEvent.Response.Usage.OutputTokens > 0 {
|
|
outputTokens = completedEvent.Response.Usage.OutputTokens
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
_, writeError := writer.Write(buffer[:n])
|
|
if writeError != nil {
|
|
s.logger.Error("write streaming response: ", writeError)
|
|
return
|
|
}
|
|
flusher.Flush()
|
|
}
|
|
|
|
if err != nil {
|
|
if responseModel == "" {
|
|
responseModel = requestModel
|
|
}
|
|
|
|
if inputTokens > 0 || outputTokens > 0 {
|
|
if responseModel != "" {
|
|
s.usageTracker.AddUsage(responseModel, inputTokens, outputTokens, cachedTokens, username)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Service) Close() error {
|
|
err := common.Close(
|
|
common.PtrOrNil(s.httpServer),
|
|
common.PtrOrNil(s.listener),
|
|
s.tlsConfig,
|
|
)
|
|
|
|
if s.usageTracker != nil {
|
|
s.usageTracker.cancelPendingSave()
|
|
saveErr := s.usageTracker.Save()
|
|
if saveErr != nil {
|
|
s.logger.Error("save usage statistics: ", saveErr)
|
|
}
|
|
}
|
|
|
|
return err
|
|
}
|