fix(ccm): robust account UUID injection and session ID validation

Replace bytes.Replace-based UUID injection with proper JSON
unmarshal/re-marshal through map[string]json.RawMessage — the old
approach silently failed when the body used non-canonical JSON escaping.

Return 500 when metadata.user_id is present but in an unrecognized
format, instead of silently passing through with an empty session ID.
This commit is contained in:
世界
2026-03-21 11:00:05 +08:00
parent 53f832330d
commit 29b901a8b3
2 changed files with 69 additions and 23 deletions

View File

@@ -851,37 +851,71 @@ func (c *defaultCredential) injectAccountUUID(bodyBytes []byte) []byte {
return bodyBytes
}
var body struct {
Metadata struct {
UserID string `json:"user_id"`
} `json:"metadata"`
var body map[string]json.RawMessage
err := json.Unmarshal(bodyBytes, &body)
if err != nil {
return bodyBytes
}
if json.Unmarshal(bodyBytes, &body) != nil || body.Metadata.UserID == "" {
metadataRaw, hasMetadata := body["metadata"]
if !hasMetadata {
return bodyBytes
}
var userIDObject map[string]any
if json.Unmarshal([]byte(body.Metadata.UserID), &userIDObject) != nil {
var metadata map[string]json.RawMessage
err = json.Unmarshal(metadataRaw, &metadata)
if err != nil {
return bodyBytes
}
existing, _ := userIDObject["account_uuid"].(string)
if existing != "" {
userIDRaw, hasUserID := metadata["user_id"]
if !hasUserID {
return bodyBytes
}
userIDObject["account_uuid"] = accountUUID
newUserID, err := json.Marshal(userIDObject)
var userIDStr string
err = json.Unmarshal(userIDRaw, &userIDStr)
if err != nil || userIDStr == "" {
return bodyBytes
}
var userIDObject map[string]json.RawMessage
err = json.Unmarshal([]byte(userIDStr), &userIDObject)
if err != nil {
return bodyBytes
}
newUserIDStr := string(newUserID)
oldUserIDJSON, err := json.Marshal(body.Metadata.UserID)
existingRaw, hasExisting := userIDObject["account_uuid"]
if hasExisting {
var existing string
if json.Unmarshal(existingRaw, &existing) == nil && existing != "" {
return bodyBytes
}
}
accountUUIDJSON, err := json.Marshal(accountUUID)
if err != nil {
return bodyBytes
}
newUserIDJSON, err := json.Marshal(newUserIDStr)
userIDObject["account_uuid"] = json.RawMessage(accountUUIDJSON)
newUserIDBytes, err := json.Marshal(userIDObject)
if err != nil {
return bodyBytes
}
return bytes.Replace(bodyBytes, oldUserIDJSON, newUserIDJSON, 1)
newUserIDRaw, err := json.Marshal(string(newUserIDBytes))
if err != nil {
return bodyBytes
}
metadata["user_id"] = json.RawMessage(newUserIDRaw)
newMetadataBytes, err := json.Marshal(metadata)
if err != nil {
return bodyBytes
}
body["metadata"] = json.RawMessage(newMetadataBytes)
newBodyBytes, err := json.Marshal(body)
if err != nil {
return bodyBytes
}
return newBodyBytes
}

View File

@@ -82,15 +82,21 @@ func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
// `user_${deviceId}_account_${accountUuid}_session_${sessionId}`
//
// ref: cli.js qs() — old metadata constructor
func extractCCMSessionID(bodyBytes []byte) string {
//
// Returns ("", nil) when body has no metadata.user_id (non-message endpoints).
// Returns error when user_id is present but in an unrecognized format.
func extractCCMSessionID(bodyBytes []byte) (string, error) {
var body struct {
Metadata struct {
Metadata *struct {
UserID string `json:"user_id"`
} `json:"metadata"`
}
err := json.Unmarshal(bodyBytes, &body)
if err != nil {
return ""
return "", nil
}
if body.Metadata == nil || body.Metadata.UserID == "" {
return "", nil
}
userID := body.Metadata.UserID
@@ -99,15 +105,16 @@ func extractCCMSessionID(bodyBytes []byte) string {
SessionID string `json:"session_id"`
}
if json.Unmarshal([]byte(userID), &userIDObject) == nil && userIDObject.SessionID != "" {
return userIDObject.SessionID
return userIDObject.SessionID, nil
}
// legacy template literal format
sessionIndex := strings.LastIndex(userID, "_session_")
if sessionIndex < 0 {
return ""
if sessionIndex >= 0 {
return userID[sessionIndex+len("_session_"):], nil
}
return userID[sessionIndex+len("_session_"):]
return "", E.New("unrecognized metadata.user_id format: ", userID)
}
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -181,7 +188,12 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
messagesCount = len(request.Messages)
}
sessionID = extractCCMSessionID(bodyBytes)
sessionID, err = extractCCMSessionID(bodyBytes)
if err != nil {
s.logger.ErrorContext(ctx, "invalid metadata format: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "invalid metadata format")
return
}
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}