Use mcdsl/terminal for all password prompts

Replace direct golang.org/x/term calls with mcdsl/terminal.ReadPassword
across mciasctl (6 sites), mciasgrpcctl (1 site), and mciasdb (1 site).
Aligns with the new CLI security standard in engineering-standards.md.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-28 11:40:11 -07:00
parent e4220b840e
commit ecbfe1bd66
142 changed files with 10241 additions and 7788 deletions

View File

@@ -59,7 +59,8 @@ import (
"time" "time"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/term"
"git.wntrmute.dev/mc/mcdsl/terminal"
) )
// Global flags bound by the root command's PersistentFlags. // Global flags bound by the root command's PersistentFlags.
@@ -139,7 +140,7 @@ func authCmd() *cobra.Command {
// appearing in shell history, ps output, and process argument lists. // appearing in shell history, ps output, and process argument lists.
// //
// Security: terminal echo is disabled during password entry // Security: terminal echo is disabled during password entry
// (golang.org/x/term.ReadPassword); the raw byte slice is zeroed after use. // (mcdsl/terminal.ReadPassword).
func authLoginCmd() *cobra.Command { func authLoginCmd() *cobra.Command {
var username string var username string
var totpCode string var totpCode string
@@ -161,17 +162,10 @@ Example: export MCIAS_TOKEN=$(mciasctl auth login --username alice)`,
// Security: always prompt interactively; never accept password as a flag. // Security: always prompt interactively; never accept password as a flag.
// This prevents the credential from appearing in shell history, ps output, // This prevents the credential from appearing in shell history, ps output,
// and /proc/PID/cmdline. // and /proc/PID/cmdline.
fmt.Fprint(os.Stderr, "Password: ") passwd, err := terminal.ReadPassword("Password: ")
raw, err := term.ReadPassword(int(os.Stdin.Fd())) //nolint:gosec // uintptr==int on all target platforms
fmt.Fprintln(os.Stderr) // newline after hidden input
if err != nil { if err != nil {
fatalf("read password: %v", err) fatalf("read password: %v", err)
} }
passwd := string(raw)
// Zero the raw byte slice once copied into the string.
for i := range raw {
raw[i] = 0
}
body := map[string]string{ body := map[string]string{
"username": username, "username": username,
@@ -206,10 +200,10 @@ Example: export MCIAS_TOKEN=$(mciasctl auth login --username alice)`,
// command-line flags to prevent them from appearing in shell history, ps // command-line flags to prevent them from appearing in shell history, ps
// output, and process argument lists. // output, and process argument lists.
// //
// Security: terminal echo is disabled during entry (golang.org/x/term); // Security: terminal echo is disabled during entry
// raw byte slices are zeroed after use. The server requires the current // (mcdsl/terminal.ReadPassword). The server requires the current password
// password to prevent token-theft attacks. On success all other active // to prevent token-theft attacks. On success all other active sessions are
// sessions are revoked server-side. // revoked server-side.
func authChangePasswordCmd() *cobra.Command { func authChangePasswordCmd() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "change-password", Use: "change-password",
@@ -221,27 +215,15 @@ Revokes all other active sessions on success.`,
c := newController() c := newController()
// Security: always prompt interactively; never accept passwords as flags. // Security: always prompt interactively; never accept passwords as flags.
fmt.Fprint(os.Stderr, "Current password: ") currentPasswd, err := terminal.ReadPassword("Current password: ")
rawCurrent, err := term.ReadPassword(int(os.Stdin.Fd())) //nolint:gosec // uintptr==int on all target platforms
fmt.Fprintln(os.Stderr)
if err != nil { if err != nil {
fatalf("read current password: %v", err) fatalf("read current password: %v", err)
} }
currentPasswd := string(rawCurrent)
for i := range rawCurrent {
rawCurrent[i] = 0
}
fmt.Fprint(os.Stderr, "New password: ") newPasswd, err := terminal.ReadPassword("New password: ")
rawNew, err := term.ReadPassword(int(os.Stdin.Fd())) //nolint:gosec // uintptr==int on all target platforms
fmt.Fprintln(os.Stderr)
if err != nil { if err != nil {
fatalf("read new password: %v", err) fatalf("read new password: %v", err)
} }
newPasswd := string(rawNew)
for i := range rawNew {
rawNew[i] = 0
}
body := map[string]string{ body := map[string]string{
"current_password": currentPasswd, "current_password": currentPasswd,
@@ -297,20 +279,15 @@ func accountCreateCmd() *cobra.Command {
c := newController() c := newController()
// Security: always prompt interactively for human-account passwords; never // Security: always prompt interactively for human-account passwords; never
// accept them as a flag. Terminal echo is disabled; the raw byte slice is // accept them as a flag. Terminal echo is disabled via
// zeroed after conversion to string. System accounts have no password. // mcdsl/terminal.ReadPassword. System accounts have no password.
var passwd string var passwd string
if accountType == "human" { if accountType == "human" {
fmt.Fprint(os.Stderr, "Password: ") var err error
raw, err := term.ReadPassword(int(os.Stdin.Fd())) //nolint:gosec // uintptr==int on all target platforms passwd, err = terminal.ReadPassword("Password: ")
fmt.Fprintln(os.Stderr)
if err != nil { if err != nil {
fatalf("read password: %v", err) fatalf("read password: %v", err)
} }
passwd = string(raw)
for i := range raw {
raw[i] = 0
}
} }
body := map[string]string{ body := map[string]string{
@@ -405,7 +382,7 @@ func accountDeleteCmd() *cobra.Command {
// Security: the new password is always prompted interactively; it is never // Security: the new password is always prompted interactively; it is never
// accepted as a command-line flag to prevent it from appearing in shell // accepted as a command-line flag to prevent it from appearing in shell
// history, ps output, and process argument lists. Terminal echo is disabled // history, ps output, and process argument lists. Terminal echo is disabled
// (golang.org/x/term); the raw byte slice is zeroed after use. // (mcdsl/terminal.ReadPassword).
func accountSetPasswordCmd() *cobra.Command { func accountSetPasswordCmd() *cobra.Command {
var id string var id string
@@ -423,16 +400,10 @@ Revokes all active sessions for the account.`,
c := newController() c := newController()
// Security: always prompt interactively; never accept password as a flag. // Security: always prompt interactively; never accept password as a flag.
fmt.Fprint(os.Stderr, "New password: ") passwd, err := terminal.ReadPassword("New password: ")
raw, err := term.ReadPassword(int(os.Stdin.Fd())) //nolint:gosec // uintptr==int on all target platforms
fmt.Fprintln(os.Stderr)
if err != nil { if err != nil {
fatalf("read password: %v", err) fatalf("read password: %v", err)
} }
passwd := string(raw)
for i := range raw {
raw[i] = 0
}
body := map[string]string{"new_password": passwd} body := map[string]string{"new_password": passwd}
c.doRequest("PUT", "/v1/accounts/"+id+"/password", body, nil) c.doRequest("PUT", "/v1/accounts/"+id+"/password", body, nil)
@@ -684,20 +655,15 @@ func pgcredsSetCmd() *cobra.Command {
// Prompt for the Postgres password interactively if not supplied so it // Prompt for the Postgres password interactively if not supplied so it
// stays out of shell history. // stays out of shell history.
// Security: terminal echo is disabled during entry; the raw byte slice is // Security: terminal echo is disabled during entry via
// zeroed after conversion to string. // mcdsl/terminal.ReadPassword.
passwd := password passwd := password
if passwd == "" { if passwd == "" {
fmt.Fprint(os.Stderr, "Postgres password: ") var err error
raw, err := term.ReadPassword(int(os.Stdin.Fd())) //nolint:gosec // uintptr==int on all target platforms passwd, err = terminal.ReadPassword("Postgres password: ")
fmt.Fprintln(os.Stderr)
if err != nil { if err != nil {
fatalf("read password: %v", err) fatalf("read password: %v", err)
} }
passwd = string(raw)
for i := range raw {
raw[i] = 0
}
} }
body := map[string]interface{}{ body := map[string]interface{}{

View File

@@ -8,7 +8,8 @@ import (
"git.wntrmute.dev/mc/mcias/internal/auth" "git.wntrmute.dev/mc/mcias/internal/auth"
"git.wntrmute.dev/mc/mcias/internal/model" "git.wntrmute.dev/mc/mcias/internal/model"
"golang.org/x/term"
"git.wntrmute.dev/mc/mcdsl/terminal"
) )
func (t *tool) runAccount(args []string) { func (t *tool) runAccount(args []string) {
@@ -233,20 +234,14 @@ func (t *tool) accountResetTOTP(args []string) {
// readPassword reads a password from the terminal without echo. // readPassword reads a password from the terminal without echo.
// Falls back to a regular line read if stdin is not a terminal (e.g. in tests). // Falls back to a regular line read if stdin is not a terminal (e.g. in tests).
func readPassword(prompt string) (string, error) { func readPassword(prompt string) (string, error) {
pw, err := terminal.ReadPassword(prompt)
if err == nil {
return pw, nil
}
// Fallback for piped input (e.g. tests).
fmt.Fprint(os.Stderr, prompt) fmt.Fprint(os.Stderr, prompt)
fd := int(os.Stdin.Fd()) //nolint:gosec // G115: file descriptors are non-negative and fit in int on all supported platforms
if term.IsTerminal(fd) {
pw, err := term.ReadPassword(fd)
fmt.Fprintln(os.Stderr) // newline after hidden input
if err != nil {
return "", fmt.Errorf("read password from terminal: %w", err)
}
return string(pw), nil
}
// Not a terminal: read a plain line (for piped input in tests).
var line string var line string
_, err := fmt.Fscanln(os.Stdin, &line) if _, err := fmt.Fscanln(os.Stdin, &line); err != nil {
if err != nil {
return "", fmt.Errorf("read password: %w", err) return "", fmt.Errorf("read password: %w", err)
} }
return line, nil return line, nil

View File

@@ -59,11 +59,11 @@ import (
"time" "time"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/term"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"git.wntrmute.dev/mc/mcdsl/terminal"
mciasv1 "git.wntrmute.dev/mc/mcias/gen/mcias/v1" mciasv1 "git.wntrmute.dev/mc/mcias/gen/mcias/v1"
) )
@@ -213,7 +213,7 @@ func authCmd(ctl *controller) *cobra.Command {
// lists. // lists.
// //
// Security: terminal echo is disabled during password entry // Security: terminal echo is disabled during password entry
// (golang.org/x/term.ReadPassword); the raw byte slice is zeroed after use. // (mcdsl/terminal.ReadPassword).
func authLoginCmd(ctl *controller) *cobra.Command { func authLoginCmd(ctl *controller) *cobra.Command {
var ( var (
username string username string
@@ -230,17 +230,10 @@ func authLoginCmd(ctl *controller) *cobra.Command {
// Security: always prompt interactively; never accept password as a flag. // Security: always prompt interactively; never accept password as a flag.
// This prevents the credential from appearing in shell history, ps output, // This prevents the credential from appearing in shell history, ps output,
// and /proc/PID/cmdline. // and /proc/PID/cmdline.
fmt.Fprint(os.Stderr, "Password: ") passwd, err := terminal.ReadPassword("Password: ")
raw, err := term.ReadPassword(int(os.Stdin.Fd())) //nolint:gosec // uintptr==int on all target platforms
fmt.Fprintln(os.Stderr)
if err != nil { if err != nil {
fatalf("read password: %v", err) fatalf("read password: %v", err)
} }
passwd := string(raw)
// Zero the raw byte slice once copied into the string.
for i := range raw {
raw[i] = 0
}
authCl := mciasv1.NewAuthServiceClient(ctl.conn) authCl := mciasv1.NewAuthServiceClient(ctl.conn)
// Login is a public RPC — no auth context needed. // Login is a public RPC — no auth context needed.

18
go.mod
View File

@@ -3,17 +3,18 @@ module git.wntrmute.dev/mc/mcias
go 1.26.0 go 1.26.0
require ( require (
git.wntrmute.dev/mc/mcdsl v1.4.0
github.com/go-webauthn/webauthn v0.16.1 github.com/go-webauthn/webauthn v0.16.1
github.com/golang-jwt/jwt/v5 v5.3.1 github.com/golang-jwt/jwt/v5 v5.3.1
github.com/golang-migrate/migrate/v4 v4.19.1 github.com/golang-migrate/migrate/v4 v4.19.1
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/pelletier/go-toml/v2 v2.2.4 github.com/pelletier/go-toml/v2 v2.3.0
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/spf13/cobra v1.10.2
golang.org/x/crypto v0.49.0 golang.org/x/crypto v0.49.0
golang.org/x/term v0.41.0 google.golang.org/grpc v1.79.3
google.golang.org/grpc v1.74.2 google.golang.org/protobuf v1.36.10
google.golang.org/protobuf v1.36.7 modernc.org/sqlite v1.47.0
modernc.org/sqlite v1.46.1
) )
require ( require (
@@ -26,15 +27,14 @@ require (
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/spf13/cobra v1.10.2 // indirect
github.com/spf13/pflag v1.0.9 // indirect github.com/spf13/pflag v1.0.9 // indirect
github.com/x448/float16 v0.8.4 // indirect github.com/x448/float16 v0.8.4 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/net v0.51.0 // indirect golang.org/x/net v0.51.0 // indirect
golang.org/x/sys v0.42.0 // indirect golang.org/x/sys v0.42.0 // indirect
golang.org/x/term v0.41.0 // indirect
golang.org/x/text v0.35.0 // indirect golang.org/x/text v0.35.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect
modernc.org/libc v1.67.6 // indirect modernc.org/libc v1.70.0 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect modernc.org/memory v1.11.0 // indirect
) )

68
go.sum
View File

@@ -1,3 +1,7 @@
git.wntrmute.dev/mc/mcdsl v1.4.0 h1:PsEIyskcjBduwHSRwNB/U/uSeU/cv3C8MVr0SRjBRLg=
git.wntrmute.dev/mc/mcdsl v1.4.0/go.mod h1:MhYahIu7Sg53lE2zpQ20nlrsoNRjQzOJBAlCmom2wJc=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -41,8 +45,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
@@ -58,25 +62,23 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48=
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8=
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0=
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs=
go.opentelemetry.io/otel/sdk v1.36.0 h1:b6SYIuLRs88ztox4EyrvRti80uXIFy+Sqzoh9kFULbs= go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18=
go.opentelemetry.io/otel/sdk v1.36.0/go.mod h1:+lC+mTgD+MUWfjJubi2vvXWcVxyr9rmlshZni72pXeY= go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE=
go.opentelemetry.io/otel/sdk/metric v1.36.0 h1:r0ntwwGosWGaa0CrSt8cuNuTcccMXERFwHX4dThiPis= go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8=
go.opentelemetry.io/otel/sdk/metric v1.36.0/go.mod h1:qTNOhFDfKRwX0yXOqJYegL5WRaW376QbB7P4Pb0qva4= go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
@@ -92,29 +94,31 @@ golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c h1:qXWI/sQtv5UKboZ/zUk7h+mrf/lXORyI+n9DKDAusdg= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c/go.mod h1:gw1tLEfykwDz2ET4a12jcXt4couGAm7IwsVaTy0Sflo= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/grpc v1.74.2 h1:WoosgB65DlWVC9FqI82dGsZhWFNBSLjQ84bjROOpMu4= google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww=
google.golang.org/grpc v1.74.2/go.mod h1:CtQ+BGjaAIXHs/5YS3i473GqwBBa1zGQNevxdeBEXrM= google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw=
modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= modernc.org/ccgo/v4 v4.32.0/go.mod h1:6F08EBCx5uQc38kMGl+0Nm0oWczoo1c7cgpzEry7Uc0=
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw=
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
@@ -123,8 +127,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU= modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk=
modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=

View File

@@ -0,0 +1,36 @@
// Package terminal provides secure terminal input helpers for CLI tools.
package terminal
import (
"fmt"
"os"
"golang.org/x/term"
)
// ReadPassword prints the given prompt to stderr and reads a password
// from the terminal with echo disabled. It prints a newline after the
// input is complete so the cursor advances normally.
//
// The intermediate byte buffer is zeroed after conversion to string so
// the credential bytes do not linger in a reusable slice; note the
// returned string itself is immutable and cannot be zeroized — callers
// needing full control over the credential's lifetime should use
// ReadPasswordBytes instead.
func ReadPassword(prompt string) (string, error) {
	b, err := readRaw(prompt)
	if err != nil {
		return "", err
	}
	pw := string(b) // copies; the original buffer can now be scrubbed
	for i := range b {
		b[i] = 0
	}
	return pw, nil
}
// ReadPasswordBytes behaves like ReadPassword but hands back the raw
// []byte, allowing the caller to zeroize the buffer once the credential
// is no longer needed.
func ReadPasswordBytes(prompt string) ([]byte, error) {
	pw, err := readRaw(prompt)
	return pw, err
}
// readRaw writes prompt to stderr, reads a line from stdin with terminal
// echo disabled, and emits a trailing newline so the cursor moves past
// the hidden input. On failure it returns a nil slice and the error.
func readRaw(prompt string) ([]byte, error) {
	fmt.Fprint(os.Stderr, prompt)
	pw, err := term.ReadPassword(int(os.Stdin.Fd())) //nolint:gosec // fd fits in int
	// The newline is printed unconditionally: with echo off, even a
	// failed read leaves the cursor stuck on the prompt line.
	fmt.Fprintln(os.Stderr)
	if err != nil {
		return nil, err
	}
	return pw, nil
}

View File

@@ -5,3 +5,4 @@ cmd/tomljson/tomljson
cmd/tomltestgen/tomltestgen cmd/tomltestgen/tomltestgen
dist dist
tests/ tests/
test-results

View File

@@ -1,84 +1,76 @@
[service] version = "2"
golangci-lint-version = "1.39.0"
[linters-settings.wsl]
allow-assign-and-anything = true
[linters-settings.exhaustive]
default-signifies-exhaustive = true
[linters] [linters]
disable-all = true default = "none"
enable = [ enable = [
"asciicheck", "asciicheck",
"bodyclose", "bodyclose",
"cyclop",
"deadcode",
"depguard",
"dogsled", "dogsled",
"dupl", "dupl",
"durationcheck", "durationcheck",
"errcheck", "errcheck",
"errorlint", "errorlint",
"exhaustive", "exhaustive",
# "exhaustivestruct",
"exportloopref",
"forbidigo", "forbidigo",
# "forcetypeassert",
"funlen",
"gci",
# "gochecknoglobals",
"gochecknoinits", "gochecknoinits",
"gocognit",
"goconst", "goconst",
"gocritic", "gocritic",
"gocyclo", "godoclint",
"godot",
"godox",
# "goerr113",
"gofmt",
"gofumpt",
"goheader", "goheader",
"goimports",
"golint",
"gomnd",
# "gomoddirectives",
"gomodguard", "gomodguard",
"goprintffuncname", "goprintffuncname",
"gosec", "gosec",
"gosimple",
"govet", "govet",
# "ifshort",
"importas", "importas",
"ineffassign", "ineffassign",
"lll", "lll",
"makezero", "makezero",
"mirror",
"misspell", "misspell",
"nakedret", "nakedret",
"nestif",
"nilerr", "nilerr",
# "nlreturn",
"noctx", "noctx",
"nolintlint", "nolintlint",
#"paralleltest", "perfsprint",
"prealloc", "prealloc",
"predeclared", "predeclared",
"revive", "revive",
"rowserrcheck", "rowserrcheck",
"sqlclosecheck", "sqlclosecheck",
"staticcheck", "staticcheck",
"structcheck",
"stylecheck",
# "testpackage",
"thelper", "thelper",
"tparallel", "tparallel",
"typecheck",
"unconvert", "unconvert",
"unparam", "unparam",
"unused", "unused",
"varcheck", "usetesting",
"wastedassign", "wastedassign",
"whitespace", "whitespace",
# "wrapcheck", ]
# "wsl"
[linters.settings.exhaustive]
default-signifies-exhaustive = true
[linters.settings.lll]
line-length = 150
[[linters.exclusions.rules]]
path = ".test.go"
linters = ["goconst", "gosec"]
[[linters.exclusions.rules]]
path = "main.go"
linters = ["forbidigo"]
[[linters.exclusions.rules]]
path = "internal"
linters = ["revive"]
text = "(exported|indent-error-flow): "
[formatters]
enable = [
"gci",
"gofmt",
"gofumpt",
"goimports",
] ]

View File

@@ -22,7 +22,6 @@ builds:
- linux_riscv64 - linux_riscv64
- windows_amd64 - windows_amd64
- windows_arm64 - windows_arm64
- windows_arm
- darwin_amd64 - darwin_amd64
- darwin_arm64 - darwin_arm64
- id: tomljson - id: tomljson
@@ -42,7 +41,6 @@ builds:
- linux_riscv64 - linux_riscv64
- windows_amd64 - windows_amd64
- windows_arm64 - windows_arm64
- windows_arm
- darwin_amd64 - darwin_amd64
- darwin_arm64 - darwin_arm64
- id: jsontoml - id: jsontoml
@@ -62,7 +60,6 @@ builds:
- linux_arm - linux_arm
- windows_amd64 - windows_amd64
- windows_arm64 - windows_arm64
- windows_arm
- darwin_amd64 - darwin_amd64
- darwin_arm64 - darwin_arm64
universal_binaries: universal_binaries:

64
vendor/github.com/pelletier/go-toml/v2/AGENTS.md generated vendored Normal file
View File

@@ -0,0 +1,64 @@
# Agent Guidelines for go-toml
This file provides guidelines for AI agents contributing to go-toml. All agents must follow these rules derived from [CONTRIBUTING.md](./CONTRIBUTING.md).
## Project Overview
go-toml is a TOML library for Go. The goal is to provide an easy-to-use and efficient TOML implementation that gets the job done without getting in the way.
## Code Change Rules
### Backward Compatibility
- **No backward-incompatible changes** unless explicitly discussed and approved
- Avoid breaking people's programs unless absolutely necessary
### Testing Requirements
- **All bug fixes must include regression tests**
- **All new code must be tested**
- Run tests before submitting: `go test -race ./...`
- Test coverage must not decrease. Check with:
```bash
go test -covermode=atomic -coverprofile=coverage.out
go tool cover -func=coverage.out
```
- All lines of code touched by changes should be covered by tests
### Performance Requirements
- go-toml aims to stay efficient; avoid performance regressions
- Run benchmarks to verify: `go test ./... -bench=. -count=10`
- Compare results using [benchstat](https://pkg.go.dev/golang.org/x/perf/cmd/benchstat)
### Documentation
- New features or feature extensions must include documentation
- Documentation lives in [README.md](./README.md) and throughout source code
### Code Style
- Follow existing code format and structure
- Code must pass `go fmt`
- Code must pass linting with the same golangci-lint version as CI (see version in `.github/workflows/lint.yml`):
```bash
# Install specific version (check lint.yml for current version)
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin <version>
# Run linter
golangci-lint run ./...
```
### Commit Messages
- Commit messages must explain **why** the change is needed
- Keep messages clear and informative even if details are in the PR description
## Pull Request Checklist
Before submitting:
1. Tests pass (`go test -race ./...`)
2. No backward-incompatible changes (unless discussed)
3. Relevant documentation added/updated
4. No performance regression (verify with benchmarks)
5. Title is clear and understandable for changelog

View File

@@ -33,7 +33,7 @@ The documentation is present in the [README][readme] and thorough the source
code. On release, it gets updated on [pkg.go.dev][pkg.go.dev]. To make a change code. On release, it gets updated on [pkg.go.dev][pkg.go.dev]. To make a change
to the documentation, create a pull request with your proposed changes. For to the documentation, create a pull request with your proposed changes. For
simple changes like that, the easiest way to go is probably the "Fork this simple changes like that, the easiest way to go is probably the "Fork this
project and edit the file" button on Github, displayed at the top right of the project and edit the file" button on GitHub, displayed at the top right of the
file. Unless it's a trivial change (for example a typo), provide a little bit of file. Unless it's a trivial change (for example a typo), provide a little bit of
context in your pull request description or commit message. context in your pull request description or commit message.
@@ -92,6 +92,48 @@ However, given GitHub's new policy to _not_ run Actions on pull requests until a
maintainer clicks on button, it is highly recommended that you run them locally maintainer clicks on button, it is highly recommended that you run them locally
as you make changes. as you make changes.
### Test across Go versions
The repository includes tooling to test go-toml across multiple Go versions
(1.11 through 1.25) both locally and in GitHub Actions.
#### Local testing with Docker
Prerequisites: Docker installed and running, Bash shell, `rsync` command.
```bash
# Test all Go versions in parallel (default)
./test-go-versions.sh
# Test specific versions
./test-go-versions.sh 1.21 1.22 1.23
# Test sequentially (slower but uses less resources)
./test-go-versions.sh --sequential
# Verbose output with custom results directory
./test-go-versions.sh --verbose --output ./my-results 1.24 1.25
# Show all options
./test-go-versions.sh --help
```
The script creates Docker containers for each Go version and runs the full test
suite. Results are saved to a `test-results/` directory with individual logs and
a comprehensive summary report.
The script only exits with a non-zero status code if either of the two most
recent Go versions fail.
#### GitHub Actions testing (maintainers)
1. Go to the **Actions** tab in the GitHub repository
2. Select **"Go Versions Compatibility Test"** from the workflow list
3. Click **"Run workflow"**
4. Optionally customize:
- **Go versions**: Space-separated list (e.g., `1.21 1.22 1.23`)
- **Execution mode**: Parallel (faster) or sequential (more stable)
### Check coverage ### Check coverage
We use `go tool cover` to compute test coverage. Most code editors have a way to We use `go tool cover` to compute test coverage. Most code editors have a way to
@@ -111,7 +153,7 @@ code lowers the coverage.
Go-toml aims to stay efficient. We rely on a set of scenarios executed with Go's Go-toml aims to stay efficient. We rely on a set of scenarios executed with Go's
builtin benchmark systems. Because of their noisy nature, containers provided by builtin benchmark systems. Because of their noisy nature, containers provided by
Github Actions cannot be reliably used for benchmarking. As a result, you are GitHub Actions cannot be reliably used for benchmarking. As a result, you are
responsible for checking that your changes do not incur a performance penalty. responsible for checking that your changes do not incur a performance penalty.
You can run their following to execute benchmarks: You can run their following to execute benchmarks:
@@ -174,7 +216,7 @@ git pull
git tag v2.2.0 git tag v2.2.0
git push --tags git push --tags
``` ```
3. CI automatically builds a draft Github release. Review it and edit as 3. CI automatically builds a draft GitHub release. Review it and edit as
necessary. Look for "Other changes". That would indicate a pull request not necessary. Look for "Other changes". That would indicate a pull request not
labeled properly. Tweak labels and pull request titles until changelog looks labeled properly. Tweak labels and pull request titles until changelog looks
good for users. good for users.

View File

@@ -107,7 +107,11 @@ type MyConfig struct {
### Unmarshaling ### Unmarshaling
[`Unmarshal`][unmarshal] reads a TOML document and fills a Go structure with its [`Unmarshal`][unmarshal] reads a TOML document and fills a Go structure with its
content. For example: content.
Note that the struct variable names are _capitalized_, while the variables in the toml document are _lowercase_.
For example:
```go ```go
doc := ` doc := `
@@ -133,6 +137,62 @@ fmt.Println("tags:", cfg.Tags)
[unmarshal]: https://pkg.go.dev/github.com/pelletier/go-toml/v2#Unmarshal [unmarshal]: https://pkg.go.dev/github.com/pelletier/go-toml/v2#Unmarshal
Here is an example using tables with some simple nesting:
```go
doc := `
age = 45
fruits = ["apple", "pear"]
# these are very important!
[my-variables]
first = 1
second = 0.2
third = "abc"
# this is not so important.
[my-variables.b]
bfirst = 123
`
var Document struct {
Age int
Fruits []string
Myvariables struct {
First int
Second float64
Third string
B struct {
Bfirst int
}
} `toml:"my-variables"`
}
err := toml.Unmarshal([]byte(doc), &Document)
if err != nil {
panic(err)
}
fmt.Println("age:", Document.Age)
fmt.Println("fruits:", Document.Fruits)
fmt.Println("my-variables.first:", Document.Myvariables.First)
fmt.Println("my-variables.second:", Document.Myvariables.Second)
fmt.Println("my-variables.third:", Document.Myvariables.Third)
fmt.Println("my-variables.B.Bfirst:", Document.Myvariables.B.Bfirst)
// Output:
// age: 45
// fruits: [apple pear]
// my-variables.first: 1
// my-variables.second: 0.2
// my-variables.third: abc
// my-variables.B.Bfirst: 123
```
### Marshaling ### Marshaling
[`Marshal`][marshal] is the opposite of Unmarshal: it represents a Go structure [`Marshal`][marshal] is the opposite of Unmarshal: it represents a Go structure
@@ -179,12 +239,12 @@ Execution time speedup compared to other Go TOML libraries:
<tr><th>Benchmark</th><th>go-toml v1</th><th>BurntSushi/toml</th></tr> <tr><th>Benchmark</th><th>go-toml v1</th><th>BurntSushi/toml</th></tr>
</thead> </thead>
<tbody> <tbody>
<tr><td>Marshal/HugoFrontMatter-2</td><td>1.9x</td><td>2.2x</td></tr> <tr><td>Marshal/HugoFrontMatter-2</td><td>2.1x</td><td>2.0x</td></tr>
<tr><td>Marshal/ReferenceFile/map-2</td><td>1.7x</td><td>2.1x</td></tr> <tr><td>Marshal/ReferenceFile/map-2</td><td>2.0x</td><td>2.0x</td></tr>
<tr><td>Marshal/ReferenceFile/struct-2</td><td>2.2x</td><td>3.0x</td></tr> <tr><td>Marshal/ReferenceFile/struct-2</td><td>2.3x</td><td>2.5x</td></tr>
<tr><td>Unmarshal/HugoFrontMatter-2</td><td>2.9x</td><td>2.7x</td></tr> <tr><td>Unmarshal/HugoFrontMatter-2</td><td>3.3x</td><td>2.8x</td></tr>
<tr><td>Unmarshal/ReferenceFile/map-2</td><td>2.6x</td><td>2.7x</td></tr> <tr><td>Unmarshal/ReferenceFile/map-2</td><td>2.9x</td><td>3.0x</td></tr>
<tr><td>Unmarshal/ReferenceFile/struct-2</td><td>4.6x</td><td>5.1x</td></tr> <tr><td>Unmarshal/ReferenceFile/struct-2</td><td>4.8x</td><td>5.0x</td></tr>
</tbody> </tbody>
</table> </table>
<details><summary>See more</summary> <details><summary>See more</summary>
@@ -197,17 +257,17 @@ provided for completeness.</p>
<tr><th>Benchmark</th><th>go-toml v1</th><th>BurntSushi/toml</th></tr> <tr><th>Benchmark</th><th>go-toml v1</th><th>BurntSushi/toml</th></tr>
</thead> </thead>
<tbody> <tbody>
<tr><td>Marshal/SimpleDocument/map-2</td><td>1.8x</td><td>2.7x</td></tr> <tr><td>Marshal/SimpleDocument/map-2</td><td>2.0x</td><td>2.9x</td></tr>
<tr><td>Marshal/SimpleDocument/struct-2</td><td>2.7x</td><td>3.8x</td></tr> <tr><td>Marshal/SimpleDocument/struct-2</td><td>2.5x</td><td>3.6x</td></tr>
<tr><td>Unmarshal/SimpleDocument/map-2</td><td>3.8x</td><td>3.0x</td></tr> <tr><td>Unmarshal/SimpleDocument/map-2</td><td>4.2x</td><td>3.4x</td></tr>
<tr><td>Unmarshal/SimpleDocument/struct-2</td><td>5.6x</td><td>4.1x</td></tr> <tr><td>Unmarshal/SimpleDocument/struct-2</td><td>5.9x</td><td>4.4x</td></tr>
<tr><td>UnmarshalDataset/example-2</td><td>3.0x</td><td>3.2x</td></tr> <tr><td>UnmarshalDataset/example-2</td><td>3.2x</td><td>2.9x</td></tr>
<tr><td>UnmarshalDataset/code-2</td><td>2.3x</td><td>2.9x</td></tr> <tr><td>UnmarshalDataset/code-2</td><td>2.4x</td><td>2.8x</td></tr>
<tr><td>UnmarshalDataset/twitter-2</td><td>2.6x</td><td>2.7x</td></tr> <tr><td>UnmarshalDataset/twitter-2</td><td>2.7x</td><td>2.5x</td></tr>
<tr><td>UnmarshalDataset/citm_catalog-2</td><td>2.2x</td><td>2.3x</td></tr> <tr><td>UnmarshalDataset/citm_catalog-2</td><td>2.3x</td><td>2.3x</td></tr>
<tr><td>UnmarshalDataset/canada-2</td><td>1.8x</td><td>1.5x</td></tr> <tr><td>UnmarshalDataset/canada-2</td><td>1.9x</td><td>1.5x</td></tr>
<tr><td>UnmarshalDataset/config-2</td><td>4.1x</td><td>2.9x</td></tr> <tr><td>UnmarshalDataset/config-2</td><td>5.4x</td><td>3.0x</td></tr>
<tr><td>geomean</td><td>2.7x</td><td>2.8x</td></tr> <tr><td>geomean</td><td>2.9x</td><td>2.8x</td></tr>
</tbody> </tbody>
</table> </table>
<p>This table can be generated with <code>./ci.sh benchmark -a -html</code>.</p> <p>This table can be generated with <code>./ci.sh benchmark -a -html</code>.</p>

View File

@@ -147,7 +147,7 @@ bench() {
pushd "$dir" pushd "$dir"
if [ "${replace}" != "" ]; then if [ "${replace}" != "" ]; then
find ./benchmark/ -iname '*.go' -exec sed -i -E "s|github.com/pelletier/go-toml/v2|${replace}|g" {} \; find ./benchmark/ -iname '*.go' -exec sed -i -E "s|github.com/pelletier/go-toml/v2\"|${replace}\"|g" {} \;
go get "${replace}" go get "${replace}"
fi fi
@@ -195,6 +195,11 @@ for line in reversed(lines[2:]):
"%.1fx" % (float(line[3])/v2), # v1 "%.1fx" % (float(line[3])/v2), # v1
"%.1fx" % (float(line[7])/v2), # bs "%.1fx" % (float(line[7])/v2), # bs
]) ])
if not results:
print("No benchmark results to display.", file=sys.stderr)
sys.exit(1)
# move geomean to the end # move geomean to the end
results.append(results[0]) results.append(results[0])
del results[0] del results[0]

View File

@@ -230,8 +230,8 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
return t, nil, err return t, nil, err
} }
if t.Second > 60 { if t.Second > 59 {
return t, nil, unstable.NewParserError(b[6:8], "seconds cannot be greater 60") return t, nil, unstable.NewParserError(b[6:8], "seconds cannot be greater than 59")
} }
b = b[8:] b = b[8:]
@@ -279,7 +279,6 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
return t, b, nil return t, b, nil
} }
//nolint:cyclop
func parseFloat(b []byte) (float64, error) { func parseFloat(b []byte) (float64, error) {
if len(b) == 4 && (b[0] == '+' || b[0] == '-') && b[1] == 'n' && b[2] == 'a' && b[3] == 'n' { if len(b) == 4 && (b[0] == '+' || b[0] == '-') && b[1] == 'n' && b[2] == 'a' && b[3] == 'n' {
return math.NaN(), nil return math.NaN(), nil

View File

@@ -2,10 +2,10 @@ package toml
import ( import (
"fmt" "fmt"
"reflect"
"strconv" "strconv"
"strings" "strings"
"github.com/pelletier/go-toml/v2/internal/danger"
"github.com/pelletier/go-toml/v2/unstable" "github.com/pelletier/go-toml/v2/unstable"
) )
@@ -54,6 +54,18 @@ func (s *StrictMissingError) String() string {
return buf.String() return buf.String()
} }
// Unwrap returns wrapped decode errors
//
// Implements errors.Join() interface.
func (s *StrictMissingError) Unwrap() []error {
errs := make([]error, len(s.Errors))
for i := range s.Errors {
errs[i] = &s.Errors[i]
}
return errs
}
// Key represents a TOML key as a sequence of key parts.
type Key []string type Key []string
// Error returns the error message contained in the DecodeError. // Error returns the error message contained in the DecodeError.
@@ -78,7 +90,7 @@ func (e *DecodeError) Key() Key {
return e.key return e.key
} }
// decodeErrorFromHighlight creates a DecodeError referencing a highlighted // wrapDecodeError creates a DecodeError referencing a highlighted
// range of bytes from document. // range of bytes from document.
// //
// highlight needs to be a sub-slice of document, or this function panics. // highlight needs to be a sub-slice of document, or this function panics.
@@ -88,7 +100,7 @@ func (e *DecodeError) Key() Key {
// //
//nolint:funlen //nolint:funlen
func wrapDecodeError(document []byte, de *unstable.ParserError) *DecodeError { func wrapDecodeError(document []byte, de *unstable.ParserError) *DecodeError {
offset := danger.SubsliceOffset(document, de.Highlight) offset := subsliceOffset(document, de.Highlight)
errMessage := de.Error() errMessage := de.Error()
errLine, errColumn := positionAtEnd(document[:offset]) errLine, errColumn := positionAtEnd(document[:offset])
@@ -248,5 +260,24 @@ func positionAtEnd(b []byte) (row int, column int) {
} }
} }
return return row, column
}
// subsliceOffset returns the byte offset of subslice within data.
// subslice must share the same backing array as data.
func subsliceOffset(data []byte, subslice []byte) int {
if len(subslice) == 0 {
return 0
}
// Use reflect to get the data pointers of both slices.
// This is safe because we're only reading the pointer values for comparison.
dataPtr := reflect.ValueOf(data).Pointer()
subPtr := reflect.ValueOf(subslice).Pointer()
offset := int(subPtr - dataPtr)
if offset < 0 || offset > len(data) {
panic("subslice is not within data")
}
return offset
} }

View File

@@ -1,6 +1,6 @@
package characters package characters
var invalidAsciiTable = [256]bool{ var invalidASCIITable = [256]bool{
0x00: true, 0x00: true,
0x01: true, 0x01: true,
0x02: true, 0x02: true,
@@ -37,6 +37,6 @@ var invalidAsciiTable = [256]bool{
0x7F: true, 0x7F: true,
} }
func InvalidAscii(b byte) bool { func InvalidASCII(b byte) bool {
return invalidAsciiTable[b] return invalidASCIITable[b]
} }

View File

@@ -1,20 +1,12 @@
// Package characters provides functions for working with string encodings.
package characters package characters
import ( import (
"unicode/utf8" "unicode/utf8"
) )
type utf8Err struct { // Utf8TomlValidAlreadyEscaped verifies that a given string is only made of
Index int // valid UTF-8 characters allowed by the TOML spec:
Size int
}
func (u utf8Err) Zero() bool {
return u.Size == 0
}
// Verified that a given string is only made of valid UTF-8 characters allowed
// by the TOML spec:
// //
// Any Unicode character may be used except those that must be escaped: // Any Unicode character may be used except those that must be escaped:
// quotation mark, backslash, and the control characters other than tab (U+0000 // quotation mark, backslash, and the control characters other than tab (U+0000
@@ -23,8 +15,8 @@ func (u utf8Err) Zero() bool {
// It is a copy of the Go 1.17 utf8.Valid implementation, tweaked to exit early // It is a copy of the Go 1.17 utf8.Valid implementation, tweaked to exit early
// when a character is not allowed. // when a character is not allowed.
// //
// The returned utf8Err is Zero() if the string is valid, or contains the byte // The returned slice is empty if the string is valid, or contains the bytes
// index and size of the invalid character. // of the invalid character.
// //
// quotation mark => already checked // quotation mark => already checked
// backslash => already checked // backslash => already checked
@@ -32,9 +24,8 @@ func (u utf8Err) Zero() bool {
// 0x9 => tab, ok // 0x9 => tab, ok
// 0xA - 0x1F => invalid // 0xA - 0x1F => invalid
// 0x7F => invalid // 0x7F => invalid
func Utf8TomlValidAlreadyEscaped(p []byte) (err utf8Err) { func Utf8TomlValidAlreadyEscaped(p []byte) []byte {
// Fast path. Check for and skip 8 bytes of ASCII characters per iteration. // Fast path. Check for and skip 8 bytes of ASCII characters per iteration.
offset := 0
for len(p) >= 8 { for len(p) >= 8 {
// Combining two 32 bit loads allows the same code to be used // Combining two 32 bit loads allows the same code to be used
// for 32 and 64 bit platforms. // for 32 and 64 bit platforms.
@@ -48,24 +39,19 @@ func Utf8TomlValidAlreadyEscaped(p []byte) (err utf8Err) {
} }
for i, b := range p[:8] { for i, b := range p[:8] {
if InvalidAscii(b) { if InvalidASCII(b) {
err.Index = offset + i return p[i : i+1]
err.Size = 1
return
} }
} }
p = p[8:] p = p[8:]
offset += 8
} }
n := len(p) n := len(p)
for i := 0; i < n; { for i := 0; i < n; {
pi := p[i] pi := p[i]
if pi < utf8.RuneSelf { if pi < utf8.RuneSelf {
if InvalidAscii(pi) { if InvalidASCII(pi) {
err.Index = offset + i return p[i : i+1]
err.Size = 1
return
} }
i++ i++
continue continue
@@ -73,44 +59,34 @@ func Utf8TomlValidAlreadyEscaped(p []byte) (err utf8Err) {
x := first[pi] x := first[pi]
if x == xx { if x == xx {
// Illegal starter byte. // Illegal starter byte.
err.Index = offset + i return p[i : i+1]
err.Size = 1
return
} }
size := int(x & 7) size := int(x & 7)
if i+size > n { if i+size > n {
// Short or invalid. // Short or invalid.
err.Index = offset + i return p[i:n]
err.Size = n - i
return
} }
accept := acceptRanges[x>>4] accept := acceptRanges[x>>4]
if c := p[i+1]; c < accept.lo || accept.hi < c { if c := p[i+1]; c < accept.lo || accept.hi < c {
err.Index = offset + i return p[i : i+2]
err.Size = 2 } else if size == 2 { //revive:disable:empty-block
return
} else if size == 2 {
} else if c := p[i+2]; c < locb || hicb < c { } else if c := p[i+2]; c < locb || hicb < c {
err.Index = offset + i return p[i : i+3]
err.Size = 3 } else if size == 3 { //revive:disable:empty-block
return
} else if size == 3 {
} else if c := p[i+3]; c < locb || hicb < c { } else if c := p[i+3]; c < locb || hicb < c {
err.Index = offset + i return p[i : i+4]
err.Size = 4
return
} }
i += size i += size
} }
return return nil
} }
// Return the size of the next rune if valid, 0 otherwise. // Utf8ValidNext returns the size of the next rune if valid, 0 otherwise.
func Utf8ValidNext(p []byte) int { func Utf8ValidNext(p []byte) int {
c := p[0] c := p[0]
if c < utf8.RuneSelf { if c < utf8.RuneSelf {
if InvalidAscii(c) { if InvalidASCII(c) {
return 0 return 0
} }
return 1 return 1
@@ -129,10 +105,10 @@ func Utf8ValidNext(p []byte) int {
accept := acceptRanges[x>>4] accept := acceptRanges[x>>4]
if c := p[1]; c < accept.lo || accept.hi < c { if c := p[1]; c < accept.lo || accept.hi < c {
return 0 return 0
} else if size == 2 { } else if size == 2 { //nolint:revive
} else if c := p[2]; c < locb || hicb < c { } else if c := p[2]; c < locb || hicb < c {
return 0 return 0
} else if size == 3 { } else if size == 3 { //nolint:revive
} else if c := p[3]; c < locb || hicb < c { } else if c := p[3]; c < locb || hicb < c {
return 0 return 0
} }

View File

@@ -1,65 +0,0 @@
package danger
import (
"fmt"
"reflect"
"unsafe"
)
const maxInt = uintptr(int(^uint(0) >> 1))
func SubsliceOffset(data []byte, subslice []byte) int {
datap := (*reflect.SliceHeader)(unsafe.Pointer(&data))
hlp := (*reflect.SliceHeader)(unsafe.Pointer(&subslice))
if hlp.Data < datap.Data {
panic(fmt.Errorf("subslice address (%d) is before data address (%d)", hlp.Data, datap.Data))
}
offset := hlp.Data - datap.Data
if offset > maxInt {
panic(fmt.Errorf("slice offset larger than int (%d)", offset))
}
intoffset := int(offset)
if intoffset > datap.Len {
panic(fmt.Errorf("slice offset (%d) is farther than data length (%d)", intoffset, datap.Len))
}
if intoffset+hlp.Len > datap.Len {
panic(fmt.Errorf("slice ends (%d+%d) is farther than data length (%d)", intoffset, hlp.Len, datap.Len))
}
return intoffset
}
func BytesRange(start []byte, end []byte) []byte {
if start == nil || end == nil {
panic("cannot call BytesRange with nil")
}
startp := (*reflect.SliceHeader)(unsafe.Pointer(&start))
endp := (*reflect.SliceHeader)(unsafe.Pointer(&end))
if startp.Data > endp.Data {
panic(fmt.Errorf("start pointer address (%d) is after end pointer address (%d)", startp.Data, endp.Data))
}
l := startp.Len
endLen := int(endp.Data-startp.Data) + endp.Len
if endLen > l {
l = endLen
}
if l > startp.Cap {
panic(fmt.Errorf("range length is larger than capacity"))
}
return start[:l]
}
func Stride(ptr unsafe.Pointer, size uintptr, offset int) unsafe.Pointer {
// TODO: replace with unsafe.Add when Go 1.17 is released
// https://github.com/golang/go/issues/40481
return unsafe.Pointer(uintptr(ptr) + uintptr(int(size)*offset))
}

View File

@@ -1,23 +0,0 @@
package danger
import (
"reflect"
"unsafe"
)
// typeID is used as key in encoder and decoder caches to enable using
// the optimize runtime.mapaccess2_fast64 function instead of the more
// expensive lookup if we were to use reflect.Type as map key.
//
// typeID holds the pointer to the reflect.Type value, which is unique
// in the program.
//
// https://github.com/segmentio/encoding/blob/master/json/codec.go#L59-L61
type TypeID unsafe.Pointer
func MakeTypeID(t reflect.Type) TypeID {
// reflect.Type has the fields:
// typ unsafe.Pointer
// ptr unsafe.Pointer
return TypeID((*[2]unsafe.Pointer)(unsafe.Pointer(&t))[1])
}

View File

@@ -36,7 +36,7 @@ func (t *KeyTracker) Pop(node *unstable.Node) {
} }
} }
// Key returns the current key // Key returns the current key.
func (t *KeyTracker) Key() []string { func (t *KeyTracker) Key() []string {
k := make([]string, len(t.k)) k := make([]string, len(t.k))
copy(k, t.k) copy(k, t.k)

View File

@@ -288,11 +288,12 @@ func (s *SeenTracker) checkKeyValue(node *unstable.Node) (bool, error) {
idx = s.create(parentIdx, k, tableKind, false, true) idx = s.create(parentIdx, k, tableKind, false, true)
} else { } else {
entry := s.entries[idx] entry := s.entries[idx]
if it.IsLast() { switch {
case it.IsLast():
return false, fmt.Errorf("toml: key %s is already defined", string(k)) return false, fmt.Errorf("toml: key %s is already defined", string(k))
} else if entry.kind != tableKind { case entry.kind != tableKind:
return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind) return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
} else if entry.explicit { case entry.explicit:
return false, fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k)) return false, fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k))
} }
} }
@@ -309,16 +310,16 @@ func (s *SeenTracker) checkKeyValue(node *unstable.Node) (bool, error) {
return s.checkInlineTable(value) return s.checkInlineTable(value)
case unstable.Array: case unstable.Array:
return s.checkArray(value) return s.checkArray(value)
} default:
return false, nil return false, nil
} }
}
func (s *SeenTracker) checkArray(node *unstable.Node) (first bool, err error) { func (s *SeenTracker) checkArray(node *unstable.Node) (first bool, err error) {
it := node.Children() it := node.Children()
for it.Next() { for it.Next() {
n := it.Node() n := it.Node()
switch n.Kind { switch n.Kind { //nolint:exhaustive
case unstable.InlineTable: case unstable.InlineTable:
first, err = s.checkInlineTable(n) first, err = s.checkInlineTable(n)
if err != nil { if err != nil {

View File

@@ -1 +1,2 @@
// Package tracker provides functions for keeping track of AST nodes.
package tracker package tracker

View File

@@ -45,7 +45,7 @@ func (d *LocalDate) UnmarshalText(b []byte) error {
type LocalTime struct { type LocalTime struct {
Hour int // Hour of the day: [0; 24[ Hour int // Hour of the day: [0; 24[
Minute int // Minute of the hour: [0; 60[ Minute int // Minute of the hour: [0; 60[
Second int // Second of the minute: [0; 60[ Second int // Second of the minute: [0; 59]
Nanosecond int // Nanoseconds within the second: [0, 1000000000[ Nanosecond int // Nanoseconds within the second: [0, 1000000000[
Precision int // Number of digits to display for Nanosecond. Precision int // Number of digits to display for Nanosecond.
} }

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding" "encoding"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"math" "math"
@@ -42,7 +43,7 @@ type Encoder struct {
arraysMultiline bool arraysMultiline bool
indentSymbol string indentSymbol string
indentTables bool indentTables bool
marshalJsonNumbers bool marshalJSONNumbers bool
} }
// NewEncoder returns a new Encoder that writes to w. // NewEncoder returns a new Encoder that writes to w.
@@ -89,14 +90,14 @@ func (enc *Encoder) SetIndentTables(indent bool) *Encoder {
return enc return enc
} }
// SetMarshalJsonNumbers forces the encoder to serialize `json.Number` as a // SetMarshalJSONNumbers forces the encoder to serialize `json.Number` as a
// float or integer instead of relying on TextMarshaler to emit a string. // float or integer instead of relying on TextMarshaler to emit a string.
// //
// *Unstable:* This method does not follow the compatibility guarantees of // *Unstable:* This method does not follow the compatibility guarantees of
// semver. It can be changed or removed without a new major version being // semver. It can be changed or removed without a new major version being
// issued. // issued.
func (enc *Encoder) SetMarshalJsonNumbers(indent bool) *Encoder { func (enc *Encoder) SetMarshalJSONNumbers(indent bool) *Encoder {
enc.marshalJsonNumbers = indent enc.marshalJSONNumbers = indent
return enc return enc
} }
@@ -161,6 +162,8 @@ func (enc *Encoder) SetMarshalJsonNumbers(indent bool) *Encoder {
// //
// The "omitempty" option prevents empty values or groups from being emitted. // The "omitempty" option prevents empty values or groups from being emitted.
// //
// The "omitzero" option prevents zero values or groups from being emitted.
//
// The "commented" option prefixes the value and all its children with a comment // The "commented" option prefixes the value and all its children with a comment
// symbol. // symbol.
// //
@@ -177,7 +180,7 @@ func (enc *Encoder) Encode(v interface{}) error {
ctx.inline = enc.tablesInline ctx.inline = enc.tablesInline
if v == nil { if v == nil {
return fmt.Errorf("toml: cannot encode a nil interface") return errors.New("toml: cannot encode a nil interface")
} }
b, err := enc.encode(b, ctx, reflect.ValueOf(v)) b, err := enc.encode(b, ctx, reflect.ValueOf(v))
@@ -196,6 +199,7 @@ func (enc *Encoder) Encode(v interface{}) error {
type valueOptions struct { type valueOptions struct {
multiline bool multiline bool
omitempty bool omitempty bool
omitzero bool
commented bool commented bool
comment string comment string
} }
@@ -266,16 +270,15 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
case LocalDateTime: case LocalDateTime:
return append(b, x.String()...), nil return append(b, x.String()...), nil
case json.Number: case json.Number:
if enc.marshalJsonNumbers { if enc.marshalJSONNumbers {
if x == "" { /// Useful zero value. if x == "" { /// Useful zero value.
return append(b, "0"...), nil return append(b, "0"...), nil
} else if v, err := x.Int64(); err == nil { } else if v, err := x.Int64(); err == nil {
return enc.encode(b, ctx, reflect.ValueOf(v)) return enc.encode(b, ctx, reflect.ValueOf(v))
} else if f, err := x.Float64(); err == nil { } else if f, err := x.Float64(); err == nil {
return enc.encode(b, ctx, reflect.ValueOf(f)) return enc.encode(b, ctx, reflect.ValueOf(f))
} else {
return nil, fmt.Errorf("toml: unable to convert %q to int64 or float64", x)
} }
return nil, fmt.Errorf("toml: unable to convert %q to int64 or float64", x)
} }
} }
@@ -309,7 +312,7 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
return enc.encodeSlice(b, ctx, v) return enc.encodeSlice(b, ctx, v)
case reflect.Interface: case reflect.Interface:
if v.IsNil() { if v.IsNil() {
return nil, fmt.Errorf("toml: encoding a nil interface is not supported") return nil, errors.New("toml: encoding a nil interface is not supported")
} }
return enc.encode(b, ctx, v.Elem()) return enc.encode(b, ctx, v.Elem())
@@ -326,28 +329,30 @@ func (enc *Encoder) encode(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, e
case reflect.Float32: case reflect.Float32:
f := v.Float() f := v.Float()
if math.IsNaN(f) { switch {
case math.IsNaN(f):
b = append(b, "nan"...) b = append(b, "nan"...)
} else if f > math.MaxFloat32 { case f > math.MaxFloat32:
b = append(b, "inf"...) b = append(b, "inf"...)
} else if f < -math.MaxFloat32 { case f < -math.MaxFloat32:
b = append(b, "-inf"...) b = append(b, "-inf"...)
} else if math.Trunc(f) == f { case math.Trunc(f) == f:
b = strconv.AppendFloat(b, f, 'f', 1, 32) b = strconv.AppendFloat(b, f, 'f', 1, 32)
} else { default:
b = strconv.AppendFloat(b, f, 'f', -1, 32) b = strconv.AppendFloat(b, f, 'f', -1, 32)
} }
case reflect.Float64: case reflect.Float64:
f := v.Float() f := v.Float()
if math.IsNaN(f) { switch {
case math.IsNaN(f):
b = append(b, "nan"...) b = append(b, "nan"...)
} else if f > math.MaxFloat64 { case f > math.MaxFloat64:
b = append(b, "inf"...) b = append(b, "inf"...)
} else if f < -math.MaxFloat64 { case f < -math.MaxFloat64:
b = append(b, "-inf"...) b = append(b, "-inf"...)
} else if math.Trunc(f) == f { case math.Trunc(f) == f:
b = strconv.AppendFloat(b, f, 'f', 1, 64) b = strconv.AppendFloat(b, f, 'f', 1, 64)
} else { default:
b = strconv.AppendFloat(b, f, 'f', -1, 64) b = strconv.AppendFloat(b, f, 'f', -1, 64)
} }
case reflect.Bool: case reflect.Bool:
@@ -384,6 +389,31 @@ func shouldOmitEmpty(options valueOptions, v reflect.Value) bool {
return options.omitempty && isEmptyValue(v) return options.omitempty && isEmptyValue(v)
} }
func shouldOmitZero(options valueOptions, v reflect.Value) bool {
if !options.omitzero {
return false
}
// Check if the type implements isZeroer interface (has a custom IsZero method).
if v.Type().Implements(isZeroerType) {
return v.Interface().(isZeroer).IsZero()
}
// Check if pointer type implements isZeroer.
if reflect.PointerTo(v.Type()).Implements(isZeroerType) {
if v.CanAddr() {
return v.Addr().Interface().(isZeroer).IsZero()
}
// Create a temporary addressable copy to call the pointer receiver method.
pv := reflect.New(v.Type())
pv.Elem().Set(v)
return pv.Interface().(isZeroer).IsZero()
}
// Fall back to reflect's IsZero for types without custom IsZero method.
return v.IsZero()
}
func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v reflect.Value) ([]byte, error) { func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v reflect.Value) ([]byte, error) {
var err error var err error
@@ -434,9 +464,10 @@ func isEmptyValue(v reflect.Value) bool {
return v.Float() == 0 return v.Float() == 0
case reflect.Interface, reflect.Ptr: case reflect.Interface, reflect.Ptr:
return v.IsNil() return v.IsNil()
} default:
return false return false
} }
}
func isEmptyStruct(v reflect.Value) bool { func isEmptyStruct(v reflect.Value) bool {
// TODO: merge with walkStruct and cache. // TODO: merge with walkStruct and cache.
@@ -479,7 +510,7 @@ func (enc *Encoder) encodeString(b []byte, v string, options valueOptions) []byt
func needsQuoting(v string) bool { func needsQuoting(v string) bool {
// TODO: vectorize // TODO: vectorize
for _, b := range []byte(v) { for _, b := range []byte(v) {
if b == '\'' || b == '\r' || b == '\n' || characters.InvalidAscii(b) { if b == '\'' || b == '\r' || b == '\n' || characters.InvalidASCII(b) {
return true return true
} }
} }
@@ -517,12 +548,26 @@ func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byt
del = 0x7f del = 0x7f
) )
for _, r := range []byte(v) { bv := []byte(v)
for i := 0; i < len(bv); i++ {
r := bv[i]
switch r { switch r {
case '\\': case '\\':
b = append(b, `\\`...) b = append(b, `\\`...)
case '"': case '"':
if multiline {
// Quotation marks do not need to be quoted in multiline strings unless
// it contains 3 consecutive. If 3+ quotes appear, quote all of them
// because it's visually better
if i+2 > len(bv) || bv[i+1] != '"' || bv[i+2] != '"' {
b = append(b, r)
} else {
b = append(b, `\"\"\"`...)
i += 2
}
} else {
b = append(b, `\"`...) b = append(b, `\"`...)
}
case '\b': case '\b':
b = append(b, `\b`...) b = append(b, `\b`...)
case '\f': case '\f':
@@ -559,9 +604,9 @@ func (enc *Encoder) encodeUnquotedKey(b []byte, v string) []byte {
return append(b, v...) return append(b, v...)
} }
func (enc *Encoder) encodeTableHeader(ctx encoderCtx, b []byte) ([]byte, error) { func (enc *Encoder) encodeTableHeader(ctx encoderCtx, b []byte) []byte {
if len(ctx.parentKey) == 0 { if len(ctx.parentKey) == 0 {
return b, nil return b
} }
b = enc.encodeComment(ctx.indent, ctx.options.comment, b) b = enc.encodeComment(ctx.indent, ctx.options.comment, b)
@@ -581,10 +626,9 @@ func (enc *Encoder) encodeTableHeader(ctx encoderCtx, b []byte) ([]byte, error)
b = append(b, "]\n"...) b = append(b, "]\n"...)
return b, nil return b
} }
//nolint:cyclop
func (enc *Encoder) encodeKey(b []byte, k string) []byte { func (enc *Encoder) encodeKey(b []byte, k string) []byte {
needsQuotation := false needsQuotation := false
cannotUseLiteral := false cannotUseLiteral := false
@@ -621,31 +665,34 @@ func (enc *Encoder) encodeKey(b []byte, k string) []byte {
func (enc *Encoder) keyToString(k reflect.Value) (string, error) { func (enc *Encoder) keyToString(k reflect.Value) (string, error) {
keyType := k.Type() keyType := k.Type()
switch { if keyType.Implements(textMarshalerType) {
case keyType.Kind() == reflect.String:
return k.String(), nil
case keyType.Implements(textMarshalerType):
keyB, err := k.Interface().(encoding.TextMarshaler).MarshalText() keyB, err := k.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil { if err != nil {
return "", fmt.Errorf("toml: error marshalling key %v from text: %w", k, err) return "", fmt.Errorf("toml: error marshalling key %v from text: %w", k, err)
} }
return string(keyB), nil return string(keyB), nil
}
case keyType.Kind() == reflect.Int || keyType.Kind() == reflect.Int8 || keyType.Kind() == reflect.Int16 || keyType.Kind() == reflect.Int32 || keyType.Kind() == reflect.Int64: switch keyType.Kind() {
case reflect.String:
return k.String(), nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(k.Int(), 10), nil return strconv.FormatInt(k.Int(), 10), nil
case keyType.Kind() == reflect.Uint || keyType.Kind() == reflect.Uint8 || keyType.Kind() == reflect.Uint16 || keyType.Kind() == reflect.Uint32 || keyType.Kind() == reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return strconv.FormatUint(k.Uint(), 10), nil return strconv.FormatUint(k.Uint(), 10), nil
case keyType.Kind() == reflect.Float32: case reflect.Float32:
return strconv.FormatFloat(k.Float(), 'f', -1, 32), nil return strconv.FormatFloat(k.Float(), 'f', -1, 32), nil
case keyType.Kind() == reflect.Float64: case reflect.Float64:
return strconv.FormatFloat(k.Float(), 'f', -1, 64), nil return strconv.FormatFloat(k.Float(), 'f', -1, 64), nil
}
default:
return "", fmt.Errorf("toml: type %s is not supported as a map key", keyType.Kind()) return "", fmt.Errorf("toml: type %s is not supported as a map key", keyType.Kind())
} }
}
func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) { func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte, error) {
var ( var (
@@ -657,9 +704,19 @@ func (enc *Encoder) encodeMap(b []byte, ctx encoderCtx, v reflect.Value) ([]byte
for iter.Next() { for iter.Next() {
v := iter.Value() v := iter.Value()
if isNil(v) { // Handle nil values: convert nil pointers to zero value,
// skip nil interfaces and nil maps.
switch v.Kind() {
case reflect.Ptr:
if v.IsNil() {
v = reflect.Zero(v.Type().Elem())
}
case reflect.Interface, reflect.Map:
if v.IsNil() {
continue continue
} }
default:
}
k, err := enc.keyToString(iter.Key()) k, err := enc.keyToString(iter.Key())
if err != nil { if err != nil {
@@ -748,9 +805,8 @@ func walkStruct(ctx encoderCtx, t *table, v reflect.Value) {
walkStruct(ctx, t, f.Elem()) walkStruct(ctx, t, f.Elem())
} }
continue continue
} else {
k = fieldType.Name
} }
k = fieldType.Name
} }
if isNil(f) { if isNil(f) {
@@ -760,6 +816,7 @@ func walkStruct(ctx encoderCtx, t *table, v reflect.Value) {
options := valueOptions{ options := valueOptions{
multiline: opts.multiline, multiline: opts.multiline,
omitempty: opts.omitempty, omitempty: opts.omitempty,
omitzero: opts.omitzero,
commented: opts.commented, commented: opts.commented,
comment: fieldType.Tag.Get("comment"), comment: fieldType.Tag.Get("comment"),
} }
@@ -820,6 +877,7 @@ type tagOptions struct {
multiline bool multiline bool
inline bool inline bool
omitempty bool omitempty bool
omitzero bool
commented bool commented bool
} }
@@ -832,7 +890,7 @@ func parseTag(tag string) (string, tagOptions) {
} }
raw := tag[idx+1:] raw := tag[idx+1:]
tag = string(tag[:idx]) tag = tag[:idx]
for raw != "" { for raw != "" {
var o string var o string
i := strings.Index(raw, ",") i := strings.Index(raw, ",")
@@ -848,6 +906,8 @@ func parseTag(tag string) (string, tagOptions) {
opts.inline = true opts.inline = true
case "omitempty": case "omitempty":
opts.omitempty = true opts.omitempty = true
case "omitzero":
opts.omitzero = true
case "commented": case "commented":
opts.commented = true opts.commented = true
} }
@@ -866,10 +926,7 @@ func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, erro
} }
if !ctx.skipTableHeader { if !ctx.skipTableHeader {
b, err = enc.encodeTableHeader(ctx, b) b = enc.encodeTableHeader(ctx, b)
if err != nil {
return nil, err
}
if enc.indentTables && len(ctx.parentKey) > 0 { if enc.indentTables && len(ctx.parentKey) > 0 {
ctx.indent++ ctx.indent++
@@ -882,6 +939,9 @@ func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, erro
if shouldOmitEmpty(kv.Options, kv.Value) { if shouldOmitEmpty(kv.Options, kv.Value) {
continue continue
} }
if kv.Options.omitzero && shouldOmitZero(kv.Options, kv.Value) {
continue
}
hasNonEmptyKV = true hasNonEmptyKV = true
ctx.setKey(kv.Key) ctx.setKey(kv.Key)
@@ -901,6 +961,9 @@ func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, erro
if shouldOmitEmpty(table.Options, table.Value) { if shouldOmitEmpty(table.Options, table.Value) {
continue continue
} }
if table.Options.omitzero && shouldOmitZero(table.Options, table.Value) {
continue
}
if first { if first {
first = false first = false
if hasNonEmptyKV { if hasNonEmptyKV {
@@ -935,6 +998,9 @@ func (enc *Encoder) encodeTableInline(b []byte, ctx encoderCtx, t table) ([]byte
if shouldOmitEmpty(kv.Options, kv.Value) { if shouldOmitEmpty(kv.Options, kv.Value) {
continue continue
} }
if kv.Options.omitzero && shouldOmitZero(kv.Options, kv.Value) {
continue
}
if first { if first {
first = false first = false
@@ -963,11 +1029,14 @@ func willConvertToTable(ctx encoderCtx, v reflect.Value) bool {
if !v.IsValid() { if !v.IsValid() {
return false return false
} }
if v.Type() == timeType || v.Type().Implements(textMarshalerType) || (v.Kind() != reflect.Ptr && v.CanAddr() && reflect.PointerTo(v.Type()).Implements(textMarshalerType)) { t := v.Type()
if t == timeType || t.Implements(textMarshalerType) {
return false
}
if v.Kind() != reflect.Ptr && v.CanAddr() && reflect.PointerTo(t).Implements(textMarshalerType) {
return false return false
} }
t := v.Type()
switch t.Kind() { switch t.Kind() {
case reflect.Map, reflect.Struct: case reflect.Map, reflect.Struct:
return !ctx.inline return !ctx.inline

View File

@@ -1,7 +1,6 @@
package toml package toml
import ( import (
"github.com/pelletier/go-toml/v2/internal/danger"
"github.com/pelletier/go-toml/v2/internal/tracker" "github.com/pelletier/go-toml/v2/internal/tracker"
"github.com/pelletier/go-toml/v2/unstable" "github.com/pelletier/go-toml/v2/unstable"
) )
@@ -13,6 +12,9 @@ type strict struct {
key tracker.KeyTracker key tracker.KeyTracker
missing []unstable.ParserError missing []unstable.ParserError
// Reference to the document for computing key ranges.
doc []byte
} }
func (s *strict) EnterTable(node *unstable.Node) { func (s *strict) EnterTable(node *unstable.Node) {
@@ -53,7 +55,7 @@ func (s *strict) MissingTable(node *unstable.Node) {
} }
s.missing = append(s.missing, unstable.ParserError{ s.missing = append(s.missing, unstable.ParserError{
Highlight: keyLocation(node), Highlight: s.keyLocation(node),
Message: "missing table", Message: "missing table",
Key: s.key.Key(), Key: s.key.Key(),
}) })
@@ -65,7 +67,7 @@ func (s *strict) MissingField(node *unstable.Node) {
} }
s.missing = append(s.missing, unstable.ParserError{ s.missing = append(s.missing, unstable.ParserError{
Highlight: keyLocation(node), Highlight: s.keyLocation(node),
Message: "missing field", Message: "missing field",
Key: s.key.Key(), Key: s.key.Key(),
}) })
@@ -88,7 +90,7 @@ func (s *strict) Error(doc []byte) error {
return err return err
} }
func keyLocation(node *unstable.Node) []byte { func (s *strict) keyLocation(node *unstable.Node) []byte {
k := node.Key() k := node.Key()
hasOne := k.Next() hasOne := k.Next()
@@ -96,12 +98,17 @@ func keyLocation(node *unstable.Node) []byte {
panic("should not be called with empty key") panic("should not be called with empty key")
} }
start := k.Node().Data // Get the range from the first key to the last key.
end := k.Node().Data firstRaw := k.Node().Raw
lastRaw := firstRaw
for k.Next() { for k.Next() {
end = k.Node().Data lastRaw = k.Node().Raw
} }
return danger.BytesRange(start, end) // Compute the slice from the document using the ranges.
start := firstRaw.Offset
end := lastRaw.Offset + lastRaw.Length
return s.doc[start:end]
} }

View File

@@ -0,0 +1,597 @@
#!/usr/bin/env bash
# Test go-toml across multiple Go versions using Docker containers.
# Per-version logs and PASSED/FAILED status files are written to OUTPUT_DIR;
# the exit code reflects only the two most recent Go versions (see usage()).
#
# Note: -e is deliberately NOT set; failures are handled explicitly so a
# single failing version does not abort the whole matrix.
set -uo pipefail
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Go versions to test (1.11 through 1.26)
GO_VERSIONS=(
    "1.11"
    "1.12"
    "1.13"
    "1.14"
    "1.15"
    "1.16"
    "1.17"
    "1.18"
    "1.19"
    "1.20"
    "1.21"
    "1.22"
    "1.23"
    "1.24"
    "1.25"
    "1.26"
)
# Default values (overridable via command-line flags parsed below)
PARALLEL=true               # one background container per version
VERBOSE=false               # echo full build/test output after each run
OUTPUT_DIR="test-results"   # destination for logs, .status files, summary.txt
DOCKER_TIMEOUT="10m"        # applied separately to `docker build` and `docker run`
# Print CLI help to stdout. The heredoc is the single source of truth for
# flags, arguments, and exit-code semantics; keep it in sync with the parser.
usage() {
    cat << EOF
Usage: $0 [OPTIONS] [GO_VERSIONS...]
Test go-toml across multiple Go versions using Docker containers.
The script reports the lowest continuous supported Go version (where all subsequent
versions pass) and only exits with non-zero status if either of the two most recent
Go versions fail, indicating immediate attention is needed.
Note: For Go versions < 1.21, the script automatically updates go.mod to match the
target version, but older versions may still fail due to missing standard library
features (e.g., the 'slices' package introduced in Go 1.21).
OPTIONS:
    -h, --help          Show this help message
    -s, --sequential    Run tests sequentially instead of in parallel
    -v, --verbose       Enable verbose output
    -o, --output DIR    Output directory for test results (default: test-results)
    -t, --timeout TIME  Docker timeout for each test (default: 10m)
    --list              List available Go versions and exit
ARGUMENTS:
    GO_VERSIONS         Specific Go versions to test (default: all supported versions)
                        Examples: 1.21 1.22 1.23
EXAMPLES:
    $0                                          # Test all Go versions in parallel
    $0 --sequential                             # Test all Go versions sequentially
    $0 1.21 1.22 1.23                           # Test specific versions
    $0 --verbose --output ./results 1.25 1.26   # Verbose output to custom directory
EXIT CODES:
    0    Recent Go versions pass (good compatibility)
    1    Recent Go versions fail (needs attention) or script error
EOF
}
# Internal: write one timestamped, colorized line to stderr.
#   $1 = ANSI color escape, $2 = marker placed right after the timestamp
#   ("" for plain log lines), remaining args = message text.
_log_line() {
    local color="$1" marker="$2"
    shift 2
    echo -e "${color}[$(date +'%H:%M:%S')]${marker}${NC} $*" >&2
}

# Public logging helpers; all output goes to stderr so stdout stays clean.
log()         { _log_line "$BLUE"   ""    "$@"; }
log_success() { _log_line "$GREEN"  " ✓"  "$@"; }
log_error()   { _log_line "$RED"    " ✗"  "$@"; }
log_warning() { _log_line "$YELLOW" " ⚠"  "$@"; }
# Parse command line arguments. Flags are consumed first; any remaining
# positional arguments are taken as an explicit list of Go versions.
while [[ $# -gt 0 ]]; do
    case $1 in
        -h|--help)
            usage
            exit 0
            ;;
        -s|--sequential)
            PARALLEL=false
            shift
            ;;
        -v|--verbose)
            VERBOSE=true
            shift
            ;;
        -o|--output)
            # Guard against a missing value: under `set -u` a bare "$2" would
            # abort with an unhelpful "unbound variable" error.
            if [[ $# -lt 2 ]]; then
                echo "Option $1 requires an argument" >&2
                usage
                exit 1
            fi
            OUTPUT_DIR="$2"
            shift 2
            ;;
        -t|--timeout)
            if [[ $# -lt 2 ]]; then
                echo "Option $1 requires an argument" >&2
                usage
                exit 1
            fi
            DOCKER_TIMEOUT="$2"
            shift 2
            ;;
        --list)
            echo "Available Go versions:"
            printf '%s\n' "${GO_VERSIONS[@]}"
            exit 0
            ;;
        -*)
            echo "Unknown option: $1" >&2
            usage
            exit 1
            ;;
        *)
            # Remaining arguments are Go versions
            break
            ;;
    esac
done
# If specific versions provided, use those instead of defaults
if [[ $# -gt 0 ]]; then
    GO_VERSIONS=("$@")
fi
# Validate Go versions: accept only 1.11 through 1.26 (matches GO_VERSIONS).
for version in "${GO_VERSIONS[@]}"; do
    if ! [[ "$version" =~ ^1\.(1[1-9]|2[0-6])$ ]]; then
        log_error "Invalid Go version: $version. Supported versions: 1.11-1.26"
        exit 1
    fi
done
# Check if Docker is available
if ! command -v docker &> /dev/null; then
    log_error "Docker is required but not installed or not in PATH"
    exit 1
fi
# Check if Docker daemon is running
if ! docker info &> /dev/null; then
    log_error "Docker daemon is not running"
    exit 1
fi
# Create output directory
mkdir -p "$OUTPUT_DIR"
# Function to test a single Go version.
#
# Builds a throwaway Docker image from golang:<version>-alpine containing the
# current source tree, runs `go test ./...` inside it, and records the result.
#
#   $1      - Go version string, e.g. "1.23"
#   Writes  - ${OUTPUT_DIR}/go-$1.txt        combined build + test output
#             ${OUTPUT_DIR}/go-$1.txt.status literal "PASSED" or "FAILED"
#   Returns - 0 on pass; non-zero on build/test/timeout failure
#
# Safe to run several instances concurrently: each call uses its own temp
# directory and a version-unique image tag.
# NOTE(review): requires `rsync` and GNU `timeout` on the host — neither is
# checked for up front; confirm they are available in CI images.
test_go_version() {
    local go_version="$1"
    local container_name="go-toml-test-${go_version}"
    local result_file="${OUTPUT_DIR}/go-${go_version}.txt"
    local dockerfile_content
    log "Testing Go $go_version..."
    # Create a temporary Dockerfile for this version
    # For Go versions < 1.21, we need to update go.mod to match the Go version
    local needs_go_mod_update=false
    # sort -V picks the smaller of the two versions; if that is $go_version and
    # it is not exactly "1.21", then go_version < 1.21.
    if [[ $(echo "$go_version 1.21" | tr ' ' '\n' | sort -V | head -n1) == "$go_version" && "$go_version" != "1.21" ]]; then
        needs_go_mod_update=true
    fi
    # The embedded lines below are Dockerfile content inside a quoted string;
    # they must stay unindented.
    dockerfile_content="FROM golang:${go_version}-alpine
# Install git (required for go mod)
RUN apk add --no-cache git
# Set working directory
WORKDIR /app
# Copy source code
COPY . ."
    # Add go.mod update step for older Go versions
    if [[ "$needs_go_mod_update" == true ]]; then
        dockerfile_content="$dockerfile_content
# Update go.mod to match Go version (required for Go < 1.21)
RUN if [ -f go.mod ]; then sed -i 's/^go [0-9]\\+\\.[0-9]\\+\\(\\.[0-9]\\+\\)\\?/go $go_version/' go.mod; fi
# Note: Go versions < 1.21 may fail due to missing standard library packages (e.g., slices)
# This is expected for projects that use Go 1.21+ features"
    fi
    dockerfile_content="$dockerfile_content
# Run tests
CMD [\"sh\", \"-c\", \"go version && echo '--- Running go test ./... ---' && go test ./...\"]"
    # Create temporary directory for this test
    local temp_dir
    temp_dir=$(mktemp -d)
    # Copy source to temp directory (excluding test results and git)
    rsync -a --exclude="$OUTPUT_DIR" --exclude=".git" --exclude="*.test" . "$temp_dir/"
    # Create Dockerfile in temp directory
    echo "$dockerfile_content" > "$temp_dir/Dockerfile"
    # Build and run container
    local exit_code=0
    local output
    if $VERBOSE; then
        log "Building Docker image for Go $go_version..."
    fi
    # Capture both stdout and stderr, and the exit code.
    # The `cd` happens inside the command substitution's subshell, so the
    # caller's working directory is untouched. DOCKER_TIMEOUT bounds the build
    # and the run independently.
    if output=$(cd "$temp_dir" && timeout "$DOCKER_TIMEOUT" docker build -t "$container_name" . 2>&1 && \
        timeout "$DOCKER_TIMEOUT" docker run --rm "$container_name" 2>&1); then
        log_success "Go $go_version: PASSED"
        echo "PASSED" > "${result_file}.status"
    else
        # $? here is the exit status of the failed substitution above; capture
        # it before any other command overwrites it.
        exit_code=$?
        log_error "Go $go_version: FAILED (exit code: $exit_code)"
        echo "FAILED" > "${result_file}.status"
    fi
    # Save full output
    echo "$output" > "$result_file"
    # Clean up (best effort; image removal failure is ignored)
    docker rmi "$container_name" &> /dev/null || true
    rm -rf "$temp_dir"
    if $VERBOSE; then
        echo "--- Go $go_version output ---"
        echo "$output"
        echo "--- End Go $go_version output ---"
    fi
    return $exit_code
}
# Launch one background test job per configured Go version, then reap them
# all. Returns the number of versions whose tests failed (0 = all passed).
run_parallel() {
    local job_pids=()
    local failures=0
    local version idx

    log "Starting parallel tests for ${#GO_VERSIONS[@]} Go versions..."

    # Fire every job before waiting on any of them.
    for version in "${GO_VERSIONS[@]}"; do
        test_go_version "$version" &
        job_pids+=("$!")
    done

    # Collect each job's exit status; count the non-zero ones.
    for idx in "${!job_pids[@]}"; do
        if ! wait "${job_pids[$idx]}"; then
            failures=$((failures + 1))
        fi
    done

    return "$failures"
}
# Run the configured Go versions one at a time, in order.
# Returns the number of versions whose tests failed (0 = all passed).
run_sequential() {
    local failures=0
    local version

    log "Starting sequential tests for ${#GO_VERSIONS[@]} Go versions..."

    for version in "${GO_VERSIONS[@]}"; do
        test_go_version "$version" || failures=$((failures + 1))
    done

    return "$failures"
}
# Main execution.
#
# Orchestrates the whole run: dispatches the per-version tests (parallel or
# sequential), tallies results by re-reading the .status files written by
# test_go_version, writes $OUTPUT_DIR/summary.txt, prints a colorized
# breakdown, and derives the process exit code solely from the two most
# recent Go versions tested.
main() {
    local start_time
    start_time=$(date +%s)
    log "Starting Go version compatibility tests..."
    log "Testing versions: ${GO_VERSIONS[*]}"
    log "Output directory: $OUTPUT_DIR"
    log "Parallel execution: $PARALLEL"
    # failed_count is the runner's return status, i.e. the number of failing
    # versions (shell return values wrap at 256, but the matrix is only 16).
    local failed_count
    if $PARALLEL; then
        run_parallel
        failed_count=$?
    else
        run_sequential
        failed_count=$?
    fi
    local end_time
    end_time=$(date +%s)
    local duration=$((end_time - start_time))
    # Collect results for display. The .status files are the authoritative
    # record — required in parallel mode where return codes alone can't say
    # WHICH version failed.
    local passed_versions=()
    local failed_versions=()
    local unknown_versions=()
    local passed_count=0
    for version in "${GO_VERSIONS[@]}"; do
        local status_file="${OUTPUT_DIR}/go-${version}.txt.status"
        if [[ -f "$status_file" ]]; then
            local status
            status=$(cat "$status_file")
            if [[ "$status" == "PASSED" ]]; then
                passed_versions+=("$version")
                # (( )) returns 1 when the pre-increment value is 0; harmless
                # because this script does not use `set -e`.
                ((passed_count++))
            else
                failed_versions+=("$version")
            fi
        else
            unknown_versions+=("$version")
        fi
    done
    # Generate summary report (plain text, no ANSI colors, for archiving).
    local summary_file="${OUTPUT_DIR}/summary.txt"
    {
        echo "Go Version Compatibility Test Summary"
        echo "====================================="
        echo "Date: $(date)"
        echo "Duration: ${duration}s"
        echo "Parallel: $PARALLEL"
        echo ""
        echo "Results:"
        for version in "${GO_VERSIONS[@]}"; do
            local status_file="${OUTPUT_DIR}/go-${version}.txt.status"
            if [[ -f "$status_file" ]]; then
                local status
                status=$(cat "$status_file")
                if [[ "$status" == "PASSED" ]]; then
                    echo " Go $version: ✓ PASSED"
                else
                    echo " Go $version: ✗ FAILED"
                fi
            else
                echo " Go $version: ? UNKNOWN (no status file)"
            fi
        done
        echo ""
        echo "Summary: $passed_count/${#GO_VERSIONS[@]} versions passed"
        if [[ $failed_count -gt 0 ]]; then
            echo ""
            echo "Failed versions details:"
            for version in "${failed_versions[@]}"; do
                echo ""
                echo "--- Go $version (FAILED) ---"
                local result_file="${OUTPUT_DIR}/go-${version}.txt"
                if [[ -f "$result_file" ]]; then
                    tail -n 30 "$result_file"
                fi
            done
        fi
    } > "$summary_file"
    # Find lowest continuous supported version and check recent versions
    local lowest_continuous_version=""
    local recent_versions_failed=false
    # Sort versions to ensure proper order
    local sorted_versions=()
    for version in "${GO_VERSIONS[@]}"; do
        sorted_versions+=("$version")
    done
    # Sort versions numerically (1.11, 1.12, ..., 1.25)
    # NOTE(review): the IFS assignment prefix on an assignment (not a command)
    # persists for the rest of this function — IFS stays $'\n' from here on.
    # All later expansions are quoted so nothing visibly breaks, but this is
    # fragile; consider a subshell or saving/restoring IFS.
    IFS=$'\n' sorted_versions=($(sort -V <<< "${sorted_versions[*]}"))
    # Find lowest continuous supported version (all versions from this point onwards pass)
    for version in "${sorted_versions[@]}"; do
        # NOTE(review): this status_file is never read in this loop (the inner
        # loop recomputes check_status_file); it appears to be dead.
        local status_file="${OUTPUT_DIR}/go-${version}.txt.status"
        local all_subsequent_pass=true
        # Check if this version and all subsequent versions pass
        local found_current=false
        for check_version in "${sorted_versions[@]}"; do
            if [[ "$check_version" == "$version" ]]; then
                found_current=true
            fi
            if [[ "$found_current" == true ]]; then
                local check_status_file="${OUTPUT_DIR}/go-${check_version}.txt.status"
                if [[ -f "$check_status_file" ]]; then
                    local status
                    status=$(cat "$check_status_file")
                    if [[ "$status" != "PASSED" ]]; then
                        all_subsequent_pass=false
                        break
                    fi
                else
                    # Missing status counts as a failure for continuity.
                    all_subsequent_pass=false
                    break
                fi
            fi
        done
        if [[ "$all_subsequent_pass" == true ]]; then
            lowest_continuous_version="$version"
            break
        fi
    done
    # Check if the two most recent versions failed — these drive the exit code.
    local num_versions=${#sorted_versions[@]}
    if [[ $num_versions -ge 2 ]]; then
        local second_recent="${sorted_versions[$((num_versions-2))]}"
        local most_recent="${sorted_versions[$((num_versions-1))]}"
        local second_recent_status_file="${OUTPUT_DIR}/go-${second_recent}.txt.status"
        local most_recent_status_file="${OUTPUT_DIR}/go-${most_recent}.txt.status"
        local second_recent_failed=false
        local most_recent_failed=false
        if [[ -f "$second_recent_status_file" ]]; then
            local status
            status=$(cat "$second_recent_status_file")
            if [[ "$status" != "PASSED" ]]; then
                second_recent_failed=true
            fi
        else
            # No status file means the test never completed: treat as failed.
            second_recent_failed=true
        fi
        if [[ -f "$most_recent_status_file" ]]; then
            local status
            status=$(cat "$most_recent_status_file")
            if [[ "$status" != "PASSED" ]]; then
                most_recent_failed=true
            fi
        else
            most_recent_failed=true
        fi
        if [[ "$second_recent_failed" == true || "$most_recent_failed" == true ]]; then
            recent_versions_failed=true
        fi
    elif [[ $num_versions -eq 1 ]]; then
        # Only one version tested, check if it's the most recent and failed
        local only_version="${sorted_versions[0]}"
        local only_status_file="${OUTPUT_DIR}/go-${only_version}.txt.status"
        if [[ -f "$only_status_file" ]]; then
            local status
            status=$(cat "$only_status_file")
            if [[ "$status" != "PASSED" ]]; then
                recent_versions_failed=true
            fi
        else
            recent_versions_failed=true
        fi
    fi
    # Display summary
    echo ""
    log "Test completed in ${duration}s"
    log "Summary report: $summary_file"
    echo ""
    echo "========================================"
    echo " FINAL RESULTS"
    echo "========================================"
    echo ""
    # Display passed versions (in sorted order, preserving only those that passed)
    if [[ ${#passed_versions[@]} -gt 0 ]]; then
        log_success "PASSED (${#passed_versions[@]}/${#GO_VERSIONS[@]}):"
        # Sort passed versions for display
        local sorted_passed=()
        for version in "${sorted_versions[@]}"; do
            for passed_version in "${passed_versions[@]}"; do
                if [[ "$version" == "$passed_version" ]]; then
                    sorted_passed+=("$version")
                    break
                fi
            done
        done
        for version in "${sorted_passed[@]}"; do
            echo -e " ${GREEN}${NC} Go $version"
        done
        echo ""
    fi
    # Display failed versions
    if [[ ${#failed_versions[@]} -gt 0 ]]; then
        log_error "FAILED (${#failed_versions[@]}/${#GO_VERSIONS[@]}):"
        # Sort failed versions for display
        local sorted_failed=()
        for version in "${sorted_versions[@]}"; do
            for failed_version in "${failed_versions[@]}"; do
                if [[ "$version" == "$failed_version" ]]; then
                    sorted_failed+=("$version")
                    break
                fi
            done
        done
        for version in "${sorted_failed[@]}"; do
            echo -e " ${RED}${NC} Go $version"
        done
        echo ""
        # Show failure details
        echo "========================================"
        echo " FAILURE DETAILS"
        echo "========================================"
        echo ""
        for version in "${sorted_failed[@]}"; do
            echo -e "${RED}--- Go $version FAILURE LOGS (last 30 lines) ---${NC}"
            local result_file="${OUTPUT_DIR}/go-${version}.txt"
            if [[ -f "$result_file" ]]; then
                tail -n 30 "$result_file" | sed 's/^/ /'
            else
                echo " No log file found: $result_file"
            fi
            echo ""
        done
    fi
    # Display unknown versions
    if [[ ${#unknown_versions[@]} -gt 0 ]]; then
        log_warning "UNKNOWN (${#unknown_versions[@]}/${#GO_VERSIONS[@]}):"
        for version in "${unknown_versions[@]}"; do
            echo -e " ${YELLOW}?${NC} Go $version (no status file)"
        done
        echo ""
    fi
    echo "========================================"
    echo " COMPATIBILITY SUMMARY"
    echo "========================================"
    echo ""
    if [[ -n "$lowest_continuous_version" ]]; then
        log_success "Lowest continuous supported version: Go $lowest_continuous_version"
        echo " (All versions from Go $lowest_continuous_version onwards pass)"
    else
        log_error "No continuous version support found"
        echo " (No version has all subsequent versions passing)"
    fi
    echo ""
    echo "========================================"
    echo "Full detailed logs available in: $OUTPUT_DIR"
    echo "========================================"
    # Determine exit code based on recent versions
    if [[ "$recent_versions_failed" == true ]]; then
        log_error "OVERALL RESULT: Recent Go versions failed - this needs attention!"
        if [[ -n "$lowest_continuous_version" ]]; then
            echo "Note: Continuous support starts from Go $lowest_continuous_version"
        fi
        exit 1
    else
        log_success "OVERALL RESULT: Recent Go versions pass - compatibility looks good!"
        if [[ -n "$lowest_continuous_version" ]]; then
            echo "Continuous support starts from Go $lowest_continuous_version"
        fi
        exit 0
    fi
}
# Trap to clean up on exit: reap background test jobs and remove any
# leftover containers/images from interrupted runs. Every step is
# best-effort (|| true) so cleanup itself never changes the exit status.
# NOTE(review): `xargs -r` is a GNU extension — confirm the target hosts
# (e.g. macOS/BSD) if portability matters.
cleanup() {
    # Kill any remaining background processes
    jobs -p | xargs -r kill 2>/dev/null || true
    # Clean up any remaining Docker containers
    docker ps -q --filter "name=go-toml-test-" | xargs -r docker stop 2>/dev/null || true
    docker images -q --filter "reference=go-toml-test-*" | xargs -r docker rmi 2>/dev/null || true
}
trap cleanup EXIT
# Run main function
main

View File

@@ -6,9 +6,18 @@ import (
"time" "time"
) )
var timeType = reflect.TypeOf((*time.Time)(nil)).Elem() // isZeroer is used to check if a type has a custom IsZero method.
var textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() // This allows custom types to define their own zero-value semantics.
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() type isZeroer interface {
var mapStringInterfaceType = reflect.TypeOf(map[string]interface{}(nil)) IsZero() bool
var sliceInterfaceType = reflect.TypeOf([]interface{}(nil)) }
var stringType = reflect.TypeOf("")
var (
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
isZeroerType = reflect.TypeOf((*isZeroer)(nil)).Elem()
mapStringInterfaceType = reflect.TypeOf(map[string]interface{}(nil))
sliceInterfaceType = reflect.TypeOf([]interface{}(nil))
stringType = reflect.TypeOf("")
)

View File

@@ -12,7 +12,6 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/pelletier/go-toml/v2/internal/danger"
"github.com/pelletier/go-toml/v2/internal/tracker" "github.com/pelletier/go-toml/v2/internal/tracker"
"github.com/pelletier/go-toml/v2/unstable" "github.com/pelletier/go-toml/v2/unstable"
) )
@@ -57,13 +56,18 @@ func (d *Decoder) DisallowUnknownFields() *Decoder {
// EnableUnmarshalerInterface allows to enable unmarshaler interface. // EnableUnmarshalerInterface allows to enable unmarshaler interface.
// //
// With this feature enabled, types implementing the unstable/Unmarshaler // With this feature enabled, types implementing the unstable.Unmarshaler
// interface can be decoded from any structure of the document. It allows types // interface can be decoded from any structure of the document. It allows types
// that don't have a straightforward TOML representation to provide their own // that don't have a straightforward TOML representation to provide their own
// decoding logic. // decoding logic.
// //
// Currently, types can only decode from a single value. Tables and array tables // The UnmarshalTOML method receives raw TOML bytes:
// are not supported. // - For single values: the raw value bytes (e.g., `"hello"` for a string)
// - For tables: all key-value lines belonging to that table
// - For inline tables/arrays: the raw bytes of the inline structure
//
// The unstable.RawMessage type can be used to capture raw TOML bytes for
// later processing, similar to json.RawMessage.
// //
// *Unstable:* This method does not follow the compatibility guarantees of // *Unstable:* This method does not follow the compatibility guarantees of
// semver. It can be changed or removed without a new major version being // semver. It can be changed or removed without a new major version being
@@ -123,6 +127,7 @@ func (d *Decoder) Decode(v interface{}) error {
dec := decoder{ dec := decoder{
strict: strict{ strict: strict{
Enabled: d.strict, Enabled: d.strict,
doc: b,
}, },
unmarshalerInterface: d.unmarshalerInterface, unmarshalerInterface: d.unmarshalerInterface,
} }
@@ -226,7 +231,7 @@ func (d *decoder) FromParser(v interface{}) error {
} }
if r.IsNil() { if r.IsNil() {
return fmt.Errorf("toml: decoding pointer target cannot be nil") return errors.New("toml: decoding pointer target cannot be nil")
} }
r = r.Elem() r = r.Elem()
@@ -273,7 +278,7 @@ func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) err
var err error var err error
var first bool // used for to clear array tables on first use var first bool // used for to clear array tables on first use
if !(d.skipUntilTable && expr.Kind == unstable.KeyValue) { if !d.skipUntilTable || expr.Kind != unstable.KeyValue {
first, err = d.seen.CheckExpression(expr) first, err = d.seen.CheckExpression(expr)
if err != nil { if err != nil {
return err return err
@@ -378,7 +383,7 @@ func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflec
case reflect.Array: case reflect.Array:
idx := d.arrayIndex(true, v) idx := d.arrayIndex(true, v)
if idx >= v.Len() { if idx >= v.Len() {
return v, fmt.Errorf("%s at position %d", d.typeMismatchError("array table", v.Type()), idx) return v, fmt.Errorf("%w at position %d", d.typeMismatchError("array table", v.Type()), idx)
} }
elem := v.Index(idx) elem := v.Index(idx)
_, err := d.handleArrayTable(key, elem) _, err := d.handleArrayTable(key, elem)
@@ -416,28 +421,52 @@ func (d *decoder) handleArrayTableCollection(key unstable.Iterator, v reflect.Va
return v, nil return v, nil
case reflect.Slice: case reflect.Slice:
elem := v.Index(v.Len() - 1) // Create a new element when the slice is empty; otherwise operate on
// the last element.
var (
elem reflect.Value
created bool
)
if v.Len() == 0 {
created = true
elemType := v.Type().Elem()
if elemType.Kind() == reflect.Interface {
elem = makeMapStringInterface()
} else {
elem = reflect.New(elemType).Elem()
}
} else {
elem = v.Index(v.Len() - 1)
}
x, err := d.handleArrayTable(key, elem) x, err := d.handleArrayTable(key, elem)
if err != nil || d.skipUntilTable { if err != nil || d.skipUntilTable {
return reflect.Value{}, err return reflect.Value{}, err
} }
if x.IsValid() { if x.IsValid() {
if created {
elem = x
} else {
elem.Set(x) elem.Set(x)
} }
}
if created {
return reflect.Append(v, elem), nil
}
return v, err return v, err
case reflect.Array: case reflect.Array:
idx := d.arrayIndex(false, v) idx := d.arrayIndex(false, v)
if idx >= v.Len() { if idx >= v.Len() {
return v, fmt.Errorf("%s at position %d", d.typeMismatchError("array table", v.Type()), idx) return v, fmt.Errorf("%w at position %d", d.typeMismatchError("array table", v.Type()), idx)
} }
elem := v.Index(idx) elem := v.Index(idx)
_, err := d.handleArrayTable(key, elem) _, err := d.handleArrayTable(key, elem)
return v, err return v, err
} default:
return d.handleArrayTable(key, v) return d.handleArrayTable(key, v)
} }
}
func (d *decoder) handleKeyPart(key unstable.Iterator, v reflect.Value, nextFn handlerFn, makeFn valueMakerFn) (reflect.Value, error) { func (d *decoder) handleKeyPart(key unstable.Iterator, v reflect.Value, nextFn handlerFn, makeFn valueMakerFn) (reflect.Value, error) {
var rv reflect.Value var rv reflect.Value
@@ -470,7 +499,8 @@ func (d *decoder) handleKeyPart(key unstable.Iterator, v reflect.Value, nextFn h
mv := v.MapIndex(mk) mv := v.MapIndex(mk)
set := false set := false
if !mv.IsValid() { switch {
case !mv.IsValid():
// If there is no value in the map, create a new one according to // If there is no value in the map, create a new one according to
// the map type. If the element type is interface, create either a // the map type. If the element type is interface, create either a
// map[string]interface{} or a []interface{} depending on whether // map[string]interface{} or a []interface{} depending on whether
@@ -483,13 +513,13 @@ func (d *decoder) handleKeyPart(key unstable.Iterator, v reflect.Value, nextFn h
mv = reflect.New(t).Elem() mv = reflect.New(t).Elem()
} }
set = true set = true
} else if mv.Kind() == reflect.Interface { case mv.Kind() == reflect.Interface:
mv = mv.Elem() mv = mv.Elem()
if !mv.IsValid() { if !mv.IsValid() {
mv = makeFn() mv = makeFn()
} }
set = true set = true
} else if !mv.CanAddr() { case !mv.CanAddr():
vt := v.Type() vt := v.Type()
t := vt.Elem() t := vt.Elem()
oldmv := mv oldmv := mv
@@ -574,9 +604,8 @@ func (d *decoder) handleArrayTablePart(key unstable.Iterator, v reflect.Value) (
// cannot handle it. // cannot handle it.
func (d *decoder) handleTable(key unstable.Iterator, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleTable(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
if v.Kind() == reflect.Slice { if v.Kind() == reflect.Slice {
if v.Len() == 0 { // For non-empty slices, work with the last element
return reflect.Value{}, unstable.NewParserError(key.Node().Data, "cannot store a table in a slice") if v.Len() > 0 {
}
elem := v.Index(v.Len() - 1) elem := v.Index(v.Len() - 1)
x, err := d.handleTable(key, elem) x, err := d.handleTable(key, elem)
if err != nil { if err != nil {
@@ -587,6 +616,17 @@ func (d *decoder) handleTable(key unstable.Iterator, v reflect.Value) (reflect.V
} }
return reflect.Value{}, nil return reflect.Value{}, nil
} }
// Empty slice - check if it implements Unmarshaler (e.g., RawMessage)
// and we're at the end of the key path
if d.unmarshalerInterface && !key.Next() {
if v.CanAddr() && v.Addr().CanInterface() {
if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok {
return d.handleKeyValuesUnmarshaler(outi)
}
}
}
return reflect.Value{}, unstable.NewParserError(key.Node().Data, "cannot store a table in a slice")
}
if key.Next() { if key.Next() {
// Still scoping the key // Still scoping the key
return d.handleTablePart(key, v) return d.handleTablePart(key, v)
@@ -599,6 +639,24 @@ func (d *decoder) handleTable(key unstable.Iterator, v reflect.Value) (reflect.V
// Handle root expressions until the end of the document or the next // Handle root expressions until the end of the document or the next
// non-key-value. // non-key-value.
func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) { func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
// Check if target implements Unmarshaler before processing key-values.
// This allows types to handle entire tables themselves.
if d.unmarshalerInterface {
vv := v
for vv.Kind() == reflect.Ptr {
if vv.IsNil() {
vv.Set(reflect.New(vv.Type().Elem()))
}
vv = vv.Elem()
}
if vv.CanAddr() && vv.Addr().CanInterface() {
if outi, ok := vv.Addr().Interface().(unstable.Unmarshaler); ok {
// Collect all key-value expressions for this table
return d.handleKeyValuesUnmarshaler(outi)
}
}
}
var rv reflect.Value var rv reflect.Value
for d.nextExpr() { for d.nextExpr() {
expr := d.expr() expr := d.expr()
@@ -628,6 +686,41 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
return rv, nil return rv, nil
} }
// handleKeyValuesUnmarshaler collects all key-value expressions for a table
// and passes them to the Unmarshaler as raw TOML bytes.
func (d *decoder) handleKeyValuesUnmarshaler(u unstable.Unmarshaler) (reflect.Value, error) {
// Collect raw bytes from all key-value expressions for this table.
// We use the Raw field on each KeyValue expression to preserve the
// original formatting (whitespace, quoting style, etc.) from the document.
var buf []byte
for d.nextExpr() {
expr := d.expr()
if expr.Kind != unstable.KeyValue {
d.stashExpr()
break
}
_, err := d.seen.CheckExpression(expr)
if err != nil {
return reflect.Value{}, err
}
// Use the raw bytes from the original document to preserve formatting
if expr.Raw.Length > 0 {
raw := d.p.Raw(expr.Raw)
buf = append(buf, raw...)
}
buf = append(buf, '\n')
}
if err := u.UnmarshalTOML(buf); err != nil {
return reflect.Value{}, err
}
return reflect.Value{}, nil
}
type ( type (
handlerFn func(key unstable.Iterator, v reflect.Value) (reflect.Value, error) handlerFn func(key unstable.Iterator, v reflect.Value) (reflect.Value, error)
valueMakerFn func() reflect.Value valueMakerFn func() reflect.Value
@@ -672,15 +765,22 @@ func (d *decoder) handleValue(value *unstable.Node, v reflect.Value) error {
if d.unmarshalerInterface { if d.unmarshalerInterface {
if v.CanAddr() && v.Addr().CanInterface() { if v.CanAddr() && v.Addr().CanInterface() {
if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok { if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok {
return outi.UnmarshalTOML(value) // Pass raw bytes from the original document
return outi.UnmarshalTOML(d.p.Raw(value.Raw))
} }
} }
} }
// Only try TextUnmarshaler for scalar types. For Array and InlineTable,
// fall through to struct/map unmarshaling to allow flexible unmarshaling
// where a type can implement UnmarshalText for string values but still
// be populated field-by-field from a table. See issue #974.
if value.Kind != unstable.Array && value.Kind != unstable.InlineTable {
ok, err := d.tryTextUnmarshaler(value, v) ok, err := d.tryTextUnmarshaler(value, v)
if ok || err != nil { if ok || err != nil {
return err return err
} }
}
switch value.Kind { switch value.Kind {
case unstable.String: case unstable.String:
@@ -821,6 +921,9 @@ func (d *decoder) unmarshalDateTime(value *unstable.Node, v reflect.Value) error
return err return err
} }
if v.Kind() != reflect.Interface && v.Type() != timeType {
return unstable.NewParserError(d.p.Raw(value.Raw), "%s", d.typeMismatchString("datetime", v.Type()))
}
v.Set(reflect.ValueOf(dt)) v.Set(reflect.ValueOf(dt))
return nil return nil
} }
@@ -831,14 +934,14 @@ func (d *decoder) unmarshalLocalDate(value *unstable.Node, v reflect.Value) erro
return err return err
} }
if v.Kind() != reflect.Interface && v.Type() != timeType {
return unstable.NewParserError(d.p.Raw(value.Raw), "%s", d.typeMismatchString("local date", v.Type()))
}
if v.Type() == timeType { if v.Type() == timeType {
cast := ld.AsTime(time.Local) v.Set(reflect.ValueOf(ld.AsTime(time.Local)))
v.Set(reflect.ValueOf(cast))
return nil return nil
} }
v.Set(reflect.ValueOf(ld)) v.Set(reflect.ValueOf(ld))
return nil return nil
} }
@@ -852,6 +955,9 @@ func (d *decoder) unmarshalLocalTime(value *unstable.Node, v reflect.Value) erro
return unstable.NewParserError(rest, "extra characters at the end of a local time") return unstable.NewParserError(rest, "extra characters at the end of a local time")
} }
if v.Kind() != reflect.Interface {
return unstable.NewParserError(d.p.Raw(value.Raw), "%s", d.typeMismatchString("local time", v.Type()))
}
v.Set(reflect.ValueOf(lt)) v.Set(reflect.ValueOf(lt))
return nil return nil
} }
@@ -866,15 +972,14 @@ func (d *decoder) unmarshalLocalDateTime(value *unstable.Node, v reflect.Value)
return unstable.NewParserError(rest, "extra characters at the end of a local date time") return unstable.NewParserError(rest, "extra characters at the end of a local date time")
} }
if v.Kind() != reflect.Interface && v.Type() != timeType {
return unstable.NewParserError(d.p.Raw(value.Raw), "%s", d.typeMismatchString("local datetime", v.Type()))
}
if v.Type() == timeType { if v.Type() == timeType {
cast := ldt.AsTime(time.Local) v.Set(reflect.ValueOf(ldt.AsTime(time.Local)))
v.Set(reflect.ValueOf(cast))
return nil return nil
} }
v.Set(reflect.ValueOf(ldt)) v.Set(reflect.ValueOf(ldt))
return nil return nil
} }
@@ -929,8 +1034,9 @@ const (
// compile time, so it is computed during initialization. // compile time, so it is computed during initialization.
var maxUint int64 = math.MaxInt64 var maxUint int64 = math.MaxInt64
func init() { func init() { //nolint:gochecknoinits
m := uint64(^uint(0)) m := uint64(^uint(0))
// #nosec G115
if m < uint64(maxUint) { if m < uint64(maxUint) {
maxUint = int64(m) maxUint = int64(m)
} }
@@ -1010,7 +1116,7 @@ func (d *decoder) unmarshalInteger(value *unstable.Node, v reflect.Value) error
case reflect.Interface: case reflect.Interface:
r = reflect.ValueOf(i) r = reflect.ValueOf(i)
default: default:
return unstable.NewParserError(d.p.Raw(value.Raw), d.typeMismatchString("integer", v.Type())) return unstable.NewParserError(d.p.Raw(value.Raw), "%s", d.typeMismatchString("integer", v.Type()))
} }
if !r.Type().AssignableTo(v.Type()) { if !r.Type().AssignableTo(v.Type()) {
@@ -1029,7 +1135,7 @@ func (d *decoder) unmarshalString(value *unstable.Node, v reflect.Value) error {
case reflect.Interface: case reflect.Interface:
v.Set(reflect.ValueOf(string(value.Data))) v.Set(reflect.ValueOf(string(value.Data)))
default: default:
return unstable.NewParserError(d.p.Raw(value.Raw), d.typeMismatchString("string", v.Type())) return unstable.NewParserError(d.p.Raw(value.Raw), "%s", d.typeMismatchString("string", v.Type()))
} }
return nil return nil
@@ -1080,36 +1186,40 @@ func (d *decoder) keyFromData(keyType reflect.Type, data []byte) (reflect.Value,
return reflect.Value{}, fmt.Errorf("toml: error unmarshalling key type %s from text: %w", stringType, err) return reflect.Value{}, fmt.Errorf("toml: error unmarshalling key type %s from text: %w", stringType, err)
} }
return mk.Elem(), nil return mk.Elem(), nil
}
case keyType.Kind() == reflect.Int || keyType.Kind() == reflect.Int8 || keyType.Kind() == reflect.Int16 || keyType.Kind() == reflect.Int32 || keyType.Kind() == reflect.Int64: switch keyType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
key, err := strconv.ParseInt(string(data), 10, 64) key, err := strconv.ParseInt(string(data), 10, 64)
if err != nil { if err != nil {
return reflect.Value{}, fmt.Errorf("toml: error parsing key of type %s from integer: %w", stringType, err) return reflect.Value{}, fmt.Errorf("toml: error parsing key of type %s from integer: %w", stringType, err)
} }
return reflect.ValueOf(key).Convert(keyType), nil return reflect.ValueOf(key).Convert(keyType), nil
case keyType.Kind() == reflect.Uint || keyType.Kind() == reflect.Uint8 || keyType.Kind() == reflect.Uint16 || keyType.Kind() == reflect.Uint32 || keyType.Kind() == reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
key, err := strconv.ParseUint(string(data), 10, 64) key, err := strconv.ParseUint(string(data), 10, 64)
if err != nil { if err != nil {
return reflect.Value{}, fmt.Errorf("toml: error parsing key of type %s from unsigned integer: %w", stringType, err) return reflect.Value{}, fmt.Errorf("toml: error parsing key of type %s from unsigned integer: %w", stringType, err)
} }
return reflect.ValueOf(key).Convert(keyType), nil return reflect.ValueOf(key).Convert(keyType), nil
case keyType.Kind() == reflect.Float32: case reflect.Float32:
key, err := strconv.ParseFloat(string(data), 32) key, err := strconv.ParseFloat(string(data), 32)
if err != nil { if err != nil {
return reflect.Value{}, fmt.Errorf("toml: error parsing key of type %s from float: %w", stringType, err) return reflect.Value{}, fmt.Errorf("toml: error parsing key of type %s from float: %w", stringType, err)
} }
return reflect.ValueOf(float32(key)), nil return reflect.ValueOf(float32(key)), nil
case keyType.Kind() == reflect.Float64: case reflect.Float64:
key, err := strconv.ParseFloat(string(data), 64) key, err := strconv.ParseFloat(string(data), 64)
if err != nil { if err != nil {
return reflect.Value{}, fmt.Errorf("toml: error parsing key of type %s from float: %w", stringType, err) return reflect.Value{}, fmt.Errorf("toml: error parsing key of type %s from float: %w", stringType, err)
} }
return reflect.ValueOf(float64(key)), nil return reflect.ValueOf(float64(key)), nil
}
default:
return reflect.Value{}, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", stringType, keyType) return reflect.Value{}, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", stringType, keyType)
} }
}
func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node, v reflect.Value) (reflect.Value, error) {
// contains the replacement for v // contains the replacement for v
@@ -1154,6 +1264,18 @@ func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node
case reflect.Struct: case reflect.Struct:
path, found := structFieldPath(v, string(key.Node().Data)) path, found := structFieldPath(v, string(key.Node().Data))
if !found { if !found {
// If no matching struct field is found but the target implements the
// unstable.Unmarshaler interface (and it is enabled), delegate the
// decoding of this value to the custom unmarshaler.
if d.unmarshalerInterface {
if v.CanAddr() && v.Addr().CanInterface() {
if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok {
// Pass raw bytes from the original document
return reflect.Value{}, outi.UnmarshalTOML(d.p.Raw(value.Raw))
}
}
}
// Otherwise, keep previous behavior and skip until the next table.
d.skipUntilTable = true d.skipUntilTable = true
break break
} }
@@ -1259,13 +1381,13 @@ func fieldByIndex(v reflect.Value, path []int) reflect.Value {
type fieldPathsMap = map[string][]int type fieldPathsMap = map[string][]int
var globalFieldPathsCache atomic.Value // map[danger.TypeID]fieldPathsMap var globalFieldPathsCache atomic.Value // map[reflect.Type]fieldPathsMap
func structFieldPath(v reflect.Value, name string) ([]int, bool) { func structFieldPath(v reflect.Value, name string) ([]int, bool) {
t := v.Type() t := v.Type()
cache, _ := globalFieldPathsCache.Load().(map[danger.TypeID]fieldPathsMap) cache, _ := globalFieldPathsCache.Load().(map[reflect.Type]fieldPathsMap)
fieldPaths, ok := cache[danger.MakeTypeID(t)] fieldPaths, ok := cache[t]
if !ok { if !ok {
fieldPaths = map[string][]int{} fieldPaths = map[string][]int{}
@@ -1276,8 +1398,8 @@ func structFieldPath(v reflect.Value, name string) ([]int, bool) {
fieldPaths[strings.ToLower(name)] = path fieldPaths[strings.ToLower(name)] = path
}) })
newCache := make(map[danger.TypeID]fieldPathsMap, len(cache)+1) newCache := make(map[reflect.Type]fieldPathsMap, len(cache)+1)
newCache[danger.MakeTypeID(t)] = fieldPaths newCache[t] = fieldPaths
for k, v := range cache { for k, v := range cache {
newCache[k] = v newCache[k] = v
} }
@@ -1301,7 +1423,9 @@ func forEachField(t reflect.Type, path []int, do func(name string, path []int))
continue continue
} }
fieldPath := append(path, i) fieldPath := make([]int, 0, len(path)+1)
fieldPath = append(fieldPath, path...)
fieldPath = append(fieldPath, i)
fieldPath = fieldPath[:len(fieldPath):len(fieldPath)] fieldPath = fieldPath[:len(fieldPath):len(fieldPath)]
name := f.Tag.Get("toml") name := f.Tag.Get("toml")

View File

@@ -1,10 +1,8 @@
package unstable package unstable
import ( import (
"errors"
"fmt" "fmt"
"unsafe"
"github.com/pelletier/go-toml/v2/internal/danger"
) )
// Iterator over a sequence of nodes. // Iterator over a sequence of nodes.
@@ -19,30 +17,43 @@ import (
// // do something with n // // do something with n
// } // }
type Iterator struct { type Iterator struct {
nodes *[]Node
idx int32
started bool started bool
node *Node
} }
// Next moves the iterator forward and returns true if points to a // Next moves the iterator forward and returns true if points to a
// node, false otherwise. // node, false otherwise.
func (c *Iterator) Next() bool { func (c *Iterator) Next() bool {
if c.nodes == nil {
return false
}
nodes := *c.nodes
if !c.started { if !c.started {
c.started = true c.started = true
} else if c.node.Valid() { } else {
c.node = c.node.Next() idx := c.idx
if idx >= 0 && int(idx) < len(nodes) {
c.idx = nodes[idx].next
} }
return c.node.Valid() }
return c.idx >= 0 && int(c.idx) < len(nodes)
} }
// IsLast returns true if the current node of the iterator is the last // IsLast returns true if the current node of the iterator is the last
// one. Subsequent calls to Next() will return false. // one. Subsequent calls to Next() will return false.
func (c *Iterator) IsLast() bool { func (c *Iterator) IsLast() bool {
return c.node.next == 0 return c.nodes == nil || c.idx < 0 || (*c.nodes)[c.idx].next < 0
} }
// Node returns a pointer to the node pointed at by the iterator. // Node returns a pointer to the node pointed at by the iterator.
func (c *Iterator) Node() *Node { func (c *Iterator) Node() *Node {
return c.node if c.nodes == nil || c.idx < 0 {
return nil
}
n := &(*c.nodes)[c.idx]
n.nodes = c.nodes
return n
} }
// Node in a TOML expression AST. // Node in a TOML expression AST.
@@ -65,11 +76,12 @@ type Node struct {
Raw Range // Raw bytes from the input. Raw Range // Raw bytes from the input.
Data []byte // Node value (either allocated or referencing the input). Data []byte // Node value (either allocated or referencing the input).
// References to other nodes, as offsets in the backing array // Absolute indices into the backing nodes slice. -1 means none.
// from this node. References can go backward, so those can be next int32
// negative. child int32
next int // 0 if last element
child int // 0 if no child // Reference to the backing nodes slice for navigation.
nodes *[]Node
} }
// Range of bytes in the document. // Range of bytes in the document.
@@ -80,24 +92,24 @@ type Range struct {
// Next returns a pointer to the next node, or nil if there is no next node. // Next returns a pointer to the next node, or nil if there is no next node.
func (n *Node) Next() *Node { func (n *Node) Next() *Node {
if n.next == 0 { if n.next < 0 {
return nil return nil
} }
ptr := unsafe.Pointer(n) next := &(*n.nodes)[n.next]
size := unsafe.Sizeof(Node{}) next.nodes = n.nodes
return (*Node)(danger.Stride(ptr, size, n.next)) return next
} }
// Child returns a pointer to the first child node of this node. Other children // Child returns a pointer to the first child node of this node. Other children
// can be accessed calling Next on the first child. Returns an nil if this Node // can be accessed calling Next on the first child. Returns nil if this Node
// has no child. // has no child.
func (n *Node) Child() *Node { func (n *Node) Child() *Node {
if n.child == 0 { if n.child < 0 {
return nil return nil
} }
ptr := unsafe.Pointer(n) child := &(*n.nodes)[n.child]
size := unsafe.Sizeof(Node{}) child.nodes = n.nodes
return (*Node)(danger.Stride(ptr, size, n.child)) return child
} }
// Valid returns true if the node's kind is set (not to Invalid). // Valid returns true if the node's kind is set (not to Invalid).
@@ -111,13 +123,14 @@ func (n *Node) Valid() bool {
func (n *Node) Key() Iterator { func (n *Node) Key() Iterator {
switch n.Kind { switch n.Kind {
case KeyValue: case KeyValue:
value := n.Child() child := n.child
if !value.Valid() { if child < 0 {
panic(fmt.Errorf("KeyValue should have at least two children")) panic(errors.New("KeyValue should have at least two children"))
} }
return Iterator{node: value.Next()} valueNode := &(*n.nodes)[child]
return Iterator{nodes: n.nodes, idx: valueNode.next}
case Table, ArrayTable: case Table, ArrayTable:
return Iterator{node: n.Child()} return Iterator{nodes: n.nodes, idx: n.child}
default: default:
panic(fmt.Errorf("Key() is not supported on a %s", n.Kind)) panic(fmt.Errorf("Key() is not supported on a %s", n.Kind))
} }
@@ -132,5 +145,5 @@ func (n *Node) Value() *Node {
// Children returns an iterator over a node's children. // Children returns an iterator over a node's children.
func (n *Node) Children() Iterator { func (n *Node) Children() Iterator {
return Iterator{node: n.Child()} return Iterator{nodes: n.nodes, idx: n.child}
} }

View File

@@ -7,15 +7,6 @@ type root struct {
nodes []Node nodes []Node
} }
// Iterator over the top level nodes.
func (r *root) Iterator() Iterator {
it := Iterator{}
if len(r.nodes) > 0 {
it.node = &r.nodes[0]
}
return it
}
func (r *root) at(idx reference) *Node { func (r *root) at(idx reference) *Node {
return &r.nodes[idx] return &r.nodes[idx]
} }
@@ -33,12 +24,10 @@ type builder struct {
lastIdx int lastIdx int
} }
func (b *builder) Tree() *root {
return &b.tree
}
func (b *builder) NodeAt(ref reference) *Node { func (b *builder) NodeAt(ref reference) *Node {
return b.tree.at(ref) n := b.tree.at(ref)
n.nodes = &b.tree.nodes
return n
} }
func (b *builder) Reset() { func (b *builder) Reset() {
@@ -48,24 +37,28 @@ func (b *builder) Reset() {
func (b *builder) Push(n Node) reference { func (b *builder) Push(n Node) reference {
b.lastIdx = len(b.tree.nodes) b.lastIdx = len(b.tree.nodes)
n.next = -1
n.child = -1
b.tree.nodes = append(b.tree.nodes, n) b.tree.nodes = append(b.tree.nodes, n)
return reference(b.lastIdx) return reference(b.lastIdx)
} }
func (b *builder) PushAndChain(n Node) reference { func (b *builder) PushAndChain(n Node) reference {
newIdx := len(b.tree.nodes) newIdx := len(b.tree.nodes)
n.next = -1
n.child = -1
b.tree.nodes = append(b.tree.nodes, n) b.tree.nodes = append(b.tree.nodes, n)
if b.lastIdx >= 0 { if b.lastIdx >= 0 {
b.tree.nodes[b.lastIdx].next = newIdx - b.lastIdx b.tree.nodes[b.lastIdx].next = int32(newIdx) //nolint:gosec // TOML ASTs are small
} }
b.lastIdx = newIdx b.lastIdx = newIdx
return reference(b.lastIdx) return reference(b.lastIdx)
} }
func (b *builder) AttachChild(parent reference, child reference) { func (b *builder) AttachChild(parent reference, child reference) {
b.tree.nodes[parent].child = int(child) - int(parent) b.tree.nodes[parent].child = int32(child) //nolint:gosec // TOML ASTs are small
} }
func (b *builder) Chain(from reference, to reference) { func (b *builder) Chain(from reference, to reference) {
b.tree.nodes[from].next = int(to) - int(from) b.tree.nodes[from].next = int32(to) //nolint:gosec // TOML ASTs are small
} }

View File

@@ -6,28 +6,40 @@ import "fmt"
type Kind int type Kind int
const ( const (
// Meta // Invalid represents an invalid meta node.
Invalid Kind = iota Invalid Kind = iota
// Comment represents a comment meta node.
Comment Comment
// Key represents a key meta node.
Key Key
// Top level structures // Table represents a top-level table.
Table Table
// ArrayTable represents a top-level array table.
ArrayTable ArrayTable
// KeyValue represents a top-level key value.
KeyValue KeyValue
// Containers values // Array represents an array container value.
Array Array
// InlineTable represents an inline table container value.
InlineTable InlineTable
// Values // String represents a string value.
String String
// Bool represents a boolean value.
Bool Bool
// Float represents a floating point value.
Float Float
// Integer represents an integer value.
Integer Integer
// LocalDate represents a a local date value.
LocalDate LocalDate
// LocalTime represents a local time value.
LocalTime LocalTime
// LocalDateTime represents a local date/time value.
LocalDateTime LocalDateTime
// DateTime represents a data/time value.
DateTime DateTime
) )

View File

@@ -6,7 +6,6 @@ import (
"unicode" "unicode"
"github.com/pelletier/go-toml/v2/internal/characters" "github.com/pelletier/go-toml/v2/internal/characters"
"github.com/pelletier/go-toml/v2/internal/danger"
) )
// ParserError describes an error relative to the content of the document. // ParserError describes an error relative to the content of the document.
@@ -70,11 +69,26 @@ func (p *Parser) Data() []byte {
// panics. // panics.
func (p *Parser) Range(b []byte) Range { func (p *Parser) Range(b []byte) Range {
return Range{ return Range{
Offset: uint32(danger.SubsliceOffset(p.data, b)), Offset: uint32(p.subsliceOffset(b)), //nolint:gosec // TOML documents are small
Length: uint32(len(b)), Length: uint32(len(b)), //nolint:gosec // TOML documents are small
} }
} }
// rangeOfToken computes the Range of a token given the remaining bytes after the token.
// This is used when the token was extracted from the beginning of some position,
// and 'rest' is what remains after the token.
func (p *Parser) rangeOfToken(token, rest []byte) Range {
offset := len(p.data) - len(token) - len(rest)
return Range{Offset: uint32(offset), Length: uint32(len(token))} //nolint:gosec // TOML documents are small
}
// subsliceOffset returns the byte offset of subslice b within p.data.
// b must be a suffix (tail) of p.data.
func (p *Parser) subsliceOffset(b []byte) int {
// b is a suffix of p.data, so its offset is len(p.data) - len(b)
return len(p.data) - len(b)
}
// Raw returns the slice corresponding to the bytes in the given range. // Raw returns the slice corresponding to the bytes in the given range.
func (p *Parser) Raw(raw Range) []byte { func (p *Parser) Raw(raw Range) []byte {
return p.data[raw.Offset : raw.Offset+raw.Length] return p.data[raw.Offset : raw.Offset+raw.Length]
@@ -158,9 +172,17 @@ type Shape struct {
End Position End Position
} }
func (p *Parser) position(b []byte) Position { // Shape returns the shape of the given range in the input. Will
offset := danger.SubsliceOffset(p.data, b) // panic if the range is not a subslice of the input.
func (p *Parser) Shape(r Range) Shape {
return Shape{
Start: p.positionAt(int(r.Offset)),
End: p.positionAt(int(r.Offset + r.Length)),
}
}
// positionAt returns the position at the given byte offset in the document.
func (p *Parser) positionAt(offset int) Position {
lead := p.data[:offset] lead := p.data[:offset]
return Position{ return Position{
@@ -170,16 +192,6 @@ func (p *Parser) position(b []byte) Position {
} }
} }
// Shape returns the shape of the given range in the input. Will
// panic if the range is not a subslice of the input.
func (p *Parser) Shape(r Range) Shape {
raw := p.Raw(r)
return Shape{
Start: p.position(raw),
End: p.position(raw[r.Length:]),
}
}
func (p *Parser) parseNewline(b []byte) ([]byte, error) { func (p *Parser) parseNewline(b []byte) ([]byte, error) {
if b[0] == '\n' { if b[0] == '\n' {
return b[1:], nil return b[1:], nil
@@ -199,7 +211,7 @@ func (p *Parser) parseComment(b []byte) (reference, []byte, error) {
if p.KeepComments && err == nil { if p.KeepComments && err == nil {
ref = p.builder.Push(Node{ ref = p.builder.Push(Node{
Kind: Comment, Kind: Comment,
Raw: p.Range(data), Raw: p.rangeOfToken(data, rest),
Data: data, Data: data,
}) })
} }
@@ -316,6 +328,9 @@ func (p *Parser) parseStdTable(b []byte) (reference, []byte, error) {
func (p *Parser) parseKeyval(b []byte) (reference, []byte, error) { func (p *Parser) parseKeyval(b []byte) (reference, []byte, error) {
// keyval = key keyval-sep val // keyval = key keyval-sep val
// Track the start position for Raw range
startB := b
ref := p.builder.Push(Node{ ref := p.builder.Push(Node{
Kind: KeyValue, Kind: KeyValue,
}) })
@@ -330,7 +345,7 @@ func (p *Parser) parseKeyval(b []byte) (reference, []byte, error) {
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
if len(b) == 0 { if len(b) == 0 {
return invalidReference, nil, NewParserError(b, "expected = after a key, but the document ends there") return invalidReference, nil, NewParserError(startB[:len(startB)-len(b)], "expected = after a key, but the document ends there")
} }
b, err = expect('=', b) b, err = expect('=', b)
@@ -348,6 +363,11 @@ func (p *Parser) parseKeyval(b []byte) (reference, []byte, error) {
p.builder.Chain(valRef, key) p.builder.Chain(valRef, key)
p.builder.AttachChild(ref, valRef) p.builder.AttachChild(ref, valRef)
// Set Raw to span the entire key-value expression.
// Access the node directly in the slice to avoid the write barrier
// that NodeAt's nodes-pointer setup would trigger.
p.builder.tree.nodes[ref].Raw = p.rangeOfToken(startB[:len(startB)-len(b)], b)
return ref, b, err return ref, b, err
} }
@@ -376,7 +396,7 @@ func (p *Parser) parseVal(b []byte) (reference, []byte, error) {
if err == nil { if err == nil {
ref = p.builder.Push(Node{ ref = p.builder.Push(Node{
Kind: String, Kind: String,
Raw: p.Range(raw), Raw: p.rangeOfToken(raw, b),
Data: v, Data: v,
}) })
} }
@@ -394,7 +414,7 @@ func (p *Parser) parseVal(b []byte) (reference, []byte, error) {
if err == nil { if err == nil {
ref = p.builder.Push(Node{ ref = p.builder.Push(Node{
Kind: String, Kind: String,
Raw: p.Range(raw), Raw: p.rangeOfToken(raw, b),
Data: v, Data: v,
}) })
} }
@@ -456,7 +476,7 @@ func (p *Parser) parseInlineTable(b []byte) (reference, []byte, error) {
// inline-table-keyvals = keyval [ inline-table-sep inline-table-keyvals ] // inline-table-keyvals = keyval [ inline-table-sep inline-table-keyvals ]
parent := p.builder.Push(Node{ parent := p.builder.Push(Node{
Kind: InlineTable, Kind: InlineTable,
Raw: p.Range(b[:1]), Raw: p.rangeOfToken(b[:1], b[1:]),
}) })
first := true first := true
@@ -542,7 +562,7 @@ func (p *Parser) parseValArray(b []byte) (reference, []byte, error) {
var err error var err error
for len(b) > 0 { for len(b) > 0 {
cref := invalidReference var cref reference
cref, b, err = p.parseOptionalWhitespaceCommentNewline(b) cref, b, err = p.parseOptionalWhitespaceCommentNewline(b)
if err != nil { if err != nil {
return parent, nil, err return parent, nil, err
@@ -611,12 +631,13 @@ func (p *Parser) parseOptionalWhitespaceCommentNewline(b []byte) (reference, []b
latestCommentRef := invalidReference latestCommentRef := invalidReference
addComment := func(ref reference) { addComment := func(ref reference) {
if rootCommentRef == invalidReference { switch {
case rootCommentRef == invalidReference:
rootCommentRef = ref rootCommentRef = ref
} else if latestCommentRef == invalidReference { case latestCommentRef == invalidReference:
p.builder.AttachChild(rootCommentRef, ref) p.builder.AttachChild(rootCommentRef, ref)
latestCommentRef = ref latestCommentRef = ref
} else { default:
p.builder.Chain(latestCommentRef, ref) p.builder.Chain(latestCommentRef, ref)
latestCommentRef = ref latestCommentRef = ref
} }
@@ -704,11 +725,11 @@ func (p *Parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, er
if !escaped { if !escaped {
str := token[startIdx:endIdx] str := token[startIdx:endIdx]
verr := characters.Utf8TomlValidAlreadyEscaped(str) highlight := characters.Utf8TomlValidAlreadyEscaped(str)
if verr.Zero() { if len(highlight) == 0 {
return token, str, rest, nil return token, str, rest, nil
} }
return nil, nil, nil, NewParserError(str[verr.Index:verr.Index+verr.Size], "invalid UTF-8") return nil, nil, nil, NewParserError(highlight, "invalid UTF-8")
} }
var builder bytes.Buffer var builder bytes.Buffer
@@ -744,7 +765,7 @@ func (p *Parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, er
i += j i += j
for ; i < len(token)-3; i++ { for ; i < len(token)-3; i++ {
c := token[i] c := token[i]
if !(c == '\n' || c == '\r' || c == ' ' || c == '\t') { if c != '\n' && c != '\r' && c != ' ' && c != '\t' {
i-- i--
break break
} }
@@ -820,7 +841,7 @@ func (p *Parser) parseKey(b []byte) (reference, []byte, error) {
ref := p.builder.Push(Node{ ref := p.builder.Push(Node{
Kind: Key, Kind: Key,
Raw: p.Range(raw), Raw: p.rangeOfToken(raw, b),
Data: key, Data: key,
}) })
@@ -836,7 +857,7 @@ func (p *Parser) parseKey(b []byte) (reference, []byte, error) {
p.builder.PushAndChain(Node{ p.builder.PushAndChain(Node{
Kind: Key, Kind: Key,
Raw: p.Range(raw), Raw: p.rangeOfToken(raw, b),
Data: key, Data: key,
}) })
} else { } else {
@@ -897,11 +918,11 @@ func (p *Parser) parseBasicString(b []byte) ([]byte, []byte, []byte, error) {
// validate the string and return a direct reference to the buffer. // validate the string and return a direct reference to the buffer.
if !escaped { if !escaped {
str := token[startIdx:endIdx] str := token[startIdx:endIdx]
verr := characters.Utf8TomlValidAlreadyEscaped(str) highlight := characters.Utf8TomlValidAlreadyEscaped(str)
if verr.Zero() { if len(highlight) == 0 {
return token, str, rest, nil return token, str, rest, nil
} }
return nil, nil, nil, NewParserError(str[verr.Index:verr.Index+verr.Size], "invalid UTF-8") return nil, nil, nil, NewParserError(highlight, "invalid UTF-8")
} }
i := startIdx i := startIdx
@@ -972,7 +993,7 @@ func hexToRune(b []byte, length int) (rune, error) {
var r uint32 var r uint32
for i, c := range b { for i, c := range b {
d := uint32(0) var d uint32
switch { switch {
case '0' <= c && c <= '9': case '0' <= c && c <= '9':
d = uint32(c - '0') d = uint32(c - '0')
@@ -1013,7 +1034,7 @@ func (p *Parser) parseIntOrFloatOrDateTime(b []byte) (reference, []byte, error)
return p.builder.Push(Node{ return p.builder.Push(Node{
Kind: Float, Kind: Float,
Data: b[:3], Data: b[:3],
Raw: p.Range(b[:3]), Raw: p.rangeOfToken(b[:3], b[3:]),
}), b[3:], nil }), b[3:], nil
case 'n': case 'n':
if !scanFollowsNan(b) { if !scanFollowsNan(b) {
@@ -1023,7 +1044,7 @@ func (p *Parser) parseIntOrFloatOrDateTime(b []byte) (reference, []byte, error)
return p.builder.Push(Node{ return p.builder.Push(Node{
Kind: Float, Kind: Float,
Data: b[:3], Data: b[:3],
Raw: p.Range(b[:3]), Raw: p.rangeOfToken(b[:3], b[3:]),
}), b[3:], nil }), b[3:], nil
case '+', '-': case '+', '-':
return p.scanIntOrFloat(b) return p.scanIntOrFloat(b)
@@ -1076,7 +1097,7 @@ byteLoop:
} }
case c == 'T' || c == 't' || c == ':' || c == '.': case c == 'T' || c == 't' || c == ':' || c == '.':
hasTime = true hasTime = true
case c == '+' || c == '-' || c == 'Z' || c == 'z': case c == '+' || c == 'Z' || c == 'z':
hasTz = true hasTz = true
case c == ' ': case c == ' ':
if !seenSpace && i+1 < len(b) && isDigit(b[i+1]) { if !seenSpace && i+1 < len(b) && isDigit(b[i+1]) {
@@ -1148,7 +1169,7 @@ func (p *Parser) scanIntOrFloat(b []byte) (reference, []byte, error) {
return p.builder.Push(Node{ return p.builder.Push(Node{
Kind: Integer, Kind: Integer,
Data: b[:i], Data: b[:i],
Raw: p.Range(b[:i]), Raw: p.rangeOfToken(b[:i], b[i:]),
}), b[i:], nil }), b[i:], nil
} }
@@ -1172,7 +1193,7 @@ func (p *Parser) scanIntOrFloat(b []byte) (reference, []byte, error) {
return p.builder.Push(Node{ return p.builder.Push(Node{
Kind: Float, Kind: Float,
Data: b[:i+3], Data: b[:i+3],
Raw: p.Range(b[:i+3]), Raw: p.rangeOfToken(b[:i+3], b[i+3:]),
}), b[i+3:], nil }), b[i+3:], nil
} }
@@ -1184,7 +1205,7 @@ func (p *Parser) scanIntOrFloat(b []byte) (reference, []byte, error) {
return p.builder.Push(Node{ return p.builder.Push(Node{
Kind: Float, Kind: Float,
Data: b[:i+3], Data: b[:i+3],
Raw: p.Range(b[:i+3]), Raw: p.rangeOfToken(b[:i+3], b[i+3:]),
}), b[i+3:], nil }), b[i+3:], nil
} }
@@ -1207,7 +1228,7 @@ func (p *Parser) scanIntOrFloat(b []byte) (reference, []byte, error) {
return p.builder.Push(Node{ return p.builder.Push(Node{
Kind: kind, Kind: kind,
Data: b[:i], Data: b[:i],
Raw: p.Range(b[:i]), Raw: p.rangeOfToken(b[:i], b[i:]),
}), b[i:], nil }), b[i:], nil
} }

View File

@@ -1,7 +1,32 @@
package unstable package unstable
// The Unmarshaler interface may be implemented by types to customize their // Unmarshaler is implemented by types that can unmarshal a TOML
// behavior when being unmarshaled from a TOML document. // description of themselves. The input is a valid TOML document
// containing the relevant portion of the parsed document.
//
// For tables (including split tables defined in multiple places),
// the data contains the raw key-value bytes from the original document
// with adjusted table headers to be relative to the unmarshaling target.
type Unmarshaler interface { type Unmarshaler interface {
UnmarshalTOML(value *Node) error UnmarshalTOML(data []byte) error
}
// RawMessage is a raw encoded TOML value. It implements Unmarshaler
// and can be used to delay TOML decoding or capture raw content.
//
// Example usage:
//
// type Config struct {
// Plugin RawMessage `toml:"plugin"`
// }
//
// var cfg Config
// toml.NewDecoder(r).EnableUnmarshalerInterface().Decode(&cfg)
// // cfg.Plugin now contains the raw TOML bytes for [plugin]
type RawMessage []byte
// UnmarshalTOML implements Unmarshaler.
func (m *RawMessage) UnmarshalTOML(data []byte) error {
*m = append((*m)[0:0], data...)
return nil
} }

27
vendor/golang.org/x/exp/LICENSE generated vendored
View File

@@ -1,27 +0,0 @@
Copyright 2009 The Go Authors.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google LLC nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

22
vendor/golang.org/x/exp/PATENTS generated vendored
View File

@@ -1,22 +0,0 @@
Additional IP Rights Grant (Patents)
"This implementation" means the copyrightable works distributed by
Google as part of the Go project.
Google hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section)
patent license to make, have made, use, offer to sell, sell, import,
transfer and otherwise run, modify and propagate the contents of this
implementation of Go, where such license applies only to those patent
claims, both currently owned or controlled by Google and acquired in
the future, licensable by Google that are necessarily infringed by this
implementation of Go. This grant does not include claims that would be
infringed only as a consequence of further modification of this
implementation. If you or your agent or exclusive licensee institute or
order or agree to the institution of patent litigation against any
entity (including a cross-claim or counterclaim in a lawsuit) alleging
that this implementation of Go or any code incorporated within this
implementation of Go constitutes direct or contributory patent
infringement, or inducement of patent infringement, then any patent
rights granted to you under this License for this implementation of Go
shall terminate as of the date such litigation is filed.

View File

@@ -1,54 +0,0 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package constraints defines a set of useful constraints to be used
// with type parameters.
package constraints
import "cmp"
// Signed is a constraint that permits any signed integer type.
// If future releases of Go add new predeclared signed integer types,
// this constraint will be modified to include them.
type Signed interface {
~int | ~int8 | ~int16 | ~int32 | ~int64
}
// Unsigned is a constraint that permits any unsigned integer type.
// If future releases of Go add new predeclared unsigned integer types,
// this constraint will be modified to include them.
type Unsigned interface {
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr
}
// Integer is a constraint that permits any integer type.
// If future releases of Go add new predeclared integer types,
// this constraint will be modified to include them.
type Integer interface {
Signed | Unsigned
}
// Float is a constraint that permits any floating-point type.
// If future releases of Go add new predeclared floating-point types,
// this constraint will be modified to include them.
type Float interface {
~float32 | ~float64
}
// Complex is a constraint that permits any complex numeric type.
// If future releases of Go add new predeclared complex numeric types,
// this constraint will be modified to include them.
type Complex interface {
~complex64 | ~complex128
}
// Ordered is a constraint that permits any ordered type: any type
// that supports the operators < <= >= >.
// If future releases of Go add new ordered types,
// this constraint will be modified to include them.
//
// This type is redundant since Go 1.21 introduced [cmp.Ordered].
//
//go:fix inline
type Ordered = cmp.Ordered

View File

@@ -33,17 +33,21 @@ guidelines, there may be valid reasons to do so, but it should be rare.
## Guidelines for Pull Requests ## Guidelines for Pull Requests
How to get your contributions merged smoothly and quickly: Please read the following carefully to ensure your contributions can be merged
smoothly and quickly.
### PR Contents
- Create **small PRs** that are narrowly focused on **addressing a single - Create **small PRs** that are narrowly focused on **addressing a single
concern**. We often receive PRs that attempt to fix several things at the same concern**. We often receive PRs that attempt to fix several things at the same
time, and if one part of the PR has a problem, that will hold up the entire time, and if one part of the PR has a problem, that will hold up the entire
PR. PR.
- For **speculative changes**, consider opening an issue and discussing it - If your change does not address an **open issue** with an **agreed
first. If you are suggesting a behavioral or API change, consider starting resolution**, consider opening an issue and discussing it first. If you are
with a [gRFC proposal](https://github.com/grpc/proposal). Many new features suggesting a behavioral or API change, consider starting with a [gRFC
that are not bug fixes will require cross-language agreement. proposal](https://github.com/grpc/proposal). Many new features that are not
bug fixes will require cross-language agreement.
- If you want to fix **formatting or style**, consider whether your changes are - If you want to fix **formatting or style**, consider whether your changes are
an obvious improvement or might be considered a personal preference. If a an obvious improvement or might be considered a personal preference. If a
@@ -56,16 +60,6 @@ How to get your contributions merged smoothly and quickly:
often written as "iff". Please do not make spelling correction changes unless often written as "iff". Please do not make spelling correction changes unless
you are certain they are misspellings. you are certain they are misspellings.
- Provide a good **PR description** as a record of **what** change is being made
and **why** it was made. Link to a GitHub issue if it exists.
- Maintain a **clean commit history** and use **meaningful commit messages**.
PRs with messy commit histories are difficult to review and won't be merged.
Before sending your PR, ensure your changes are based on top of the latest
`upstream/master` commits, and avoid rebasing in the middle of a code review.
You should **never use `git push -f`** unless absolutely necessary during a
review, as it can interfere with GitHub's tracking of comments.
- **All tests need to be passing** before your change can be merged. We - **All tests need to be passing** before your change can be merged. We
recommend you run tests locally before creating your PR to catch breakages recommend you run tests locally before creating your PR to catch breakages
early on: early on:
@@ -81,15 +75,80 @@ How to get your contributions merged smoothly and quickly:
GitHub, which will trigger a GitHub Actions run that you can use to verify GitHub, which will trigger a GitHub Actions run that you can use to verify
everything is passing. everything is passing.
- If you are adding a new file, make sure it has the **copyright message** - Note that there are two GitHub actions checks that need not be green:
1. We test the freshness of the generated proto code we maintain via the
`vet-proto` check. If the source proto files are updated, but our repo is
not updated, an optional checker will fail. This will be fixed by our team
in a separate PR and will not prevent the merge of your PR.
2. We run a checker that will fail if there is any change in dependencies of
an exported package via the `dependencies` check. If new dependencies are
added that are not appropriate, we may not accept your PR (see below).
- If you are adding a **new file**, make sure it has the **copyright message**
template at the top as a comment. You can copy the message from an existing template at the top as a comment. You can copy the message from an existing
file and update the year. file and update the year.
- The grpc package should only depend on standard Go packages and a small number - The grpc package should only depend on standard Go packages and a small number
of exceptions. **If your contribution introduces new dependencies**, you will of exceptions. **If your contribution introduces new dependencies**, you will
need a discussion with gRPC-Go maintainers. A GitHub action check will run on need a discussion with gRPC-Go maintainers.
every PR, and will flag any transitive dependency changes from any public
package. ### PR Descriptions
- **PR titles** should start with the name of the component being addressed, or
the type of change. Examples: transport, client, server, round_robin, xds,
cleanup, deps.
- Read and follow the **guidelines for PR titles and descriptions** here:
https://google.github.io/eng-practices/review/developer/cl-descriptions.html
*particularly* the sections "First Line" and "Body is Informative".
Note: your PR description will be used as the git commit message in a
squash-and-merge if your PR is approved. We may make changes to this as
necessary.
- **Does this PR relate to an open issue?** On the first line, please use the
tag `Fixes #<issue>` to ensure the issue is closed when the PR is merged. Or
use `Updates #<issue>` if the PR is related to an open issue, but does not fix
it. Consider filing an issue if one does not already exist.
- PR descriptions *must* conclude with **release notes** as follows:
```
RELEASE NOTES:
* <component>: <summary>
```
This need not match the PR title.
The summary must:
* be something that gRPC users will understand.
* clearly explain the feature being added, the issue being fixed, or the
behavior being changed, etc. If fixing a bug, be clear about how the bug
can be triggered by an end-user.
* begin with a capital letter and use complete sentences.
* be as short as possible to describe the change being made.
If a PR is *not* end-user visible -- e.g. a cleanup, testing change, or
GitHub-related, use `RELEASE NOTES: n/a`.
### PR Process
- Please **self-review** your code changes before sending your PR. This will
prevent simple, obvious errors from causing delays.
- Maintain a **clean commit history** and use **meaningful commit messages**.
PRs with messy commit histories are difficult to review and won't be merged.
Before sending your PR, ensure your changes are based on top of the latest
`upstream/master` commits, and avoid rebasing in the middle of a code review.
You should **never use `git push -f`** unless absolutely necessary during a
review, as it can interfere with GitHub's tracking of comments.
- Unless your PR is trivial, you should **expect reviewer comments** that you - Unless your PR is trivial, you should **expect reviewer comments** that you
will need to address before merging. We'll label the PR as `Status: Requires will need to address before merging. We'll label the PR as `Status: Requires
@@ -98,5 +157,3 @@ How to get your contributions merged smoothly and quickly:
`stale`, and we will automatically close it after 7 days if we don't hear back `stale`, and we will automatically close it after 7 days if we don't hear back
from you. Please feel free to ping issues or bugs if you do not get a response from you. Please feel free to ping issues or bugs if you do not get a response
within a week. within a week.
- Exceptions to the rules can be made if there's a compelling reason to do so.

View File

@@ -9,21 +9,19 @@ for general contribution guidelines.
## Maintainers (in alphabetical order) ## Maintainers (in alphabetical order)
- [aranjans](https://github.com/aranjans), Google LLC
- [arjan-bal](https://github.com/arjan-bal), Google LLC - [arjan-bal](https://github.com/arjan-bal), Google LLC
- [arvindbr8](https://github.com/arvindbr8), Google LLC - [arvindbr8](https://github.com/arvindbr8), Google LLC
- [atollena](https://github.com/atollena), Datadog, Inc. - [atollena](https://github.com/atollena), Datadog, Inc.
- [dfawley](https://github.com/dfawley), Google LLC - [dfawley](https://github.com/dfawley), Google LLC
- [easwars](https://github.com/easwars), Google LLC - [easwars](https://github.com/easwars), Google LLC
- [erm-g](https://github.com/erm-g), Google LLC
- [gtcooke94](https://github.com/gtcooke94), Google LLC - [gtcooke94](https://github.com/gtcooke94), Google LLC
- [purnesh42h](https://github.com/purnesh42h), Google LLC
- [zasweq](https://github.com/zasweq), Google LLC
## Emeritus Maintainers (in alphabetical order) ## Emeritus Maintainers (in alphabetical order)
- [adelez](https://github.com/adelez) - [adelez](https://github.com/adelez)
- [aranjans](https://github.com/aranjans)
- [canguler](https://github.com/canguler) - [canguler](https://github.com/canguler)
- [cesarghali](https://github.com/cesarghali) - [cesarghali](https://github.com/cesarghali)
- [erm-g](https://github.com/erm-g)
- [iamqizhao](https://github.com/iamqizhao) - [iamqizhao](https://github.com/iamqizhao)
- [jeanbza](https://github.com/jeanbza) - [jeanbza](https://github.com/jeanbza)
- [jtattermusch](https://github.com/jtattermusch) - [jtattermusch](https://github.com/jtattermusch)
@@ -32,5 +30,7 @@ for general contribution guidelines.
- [matt-kwong](https://github.com/matt-kwong) - [matt-kwong](https://github.com/matt-kwong)
- [menghanl](https://github.com/menghanl) - [menghanl](https://github.com/menghanl)
- [nicolasnoble](https://github.com/nicolasnoble) - [nicolasnoble](https://github.com/nicolasnoble)
- [purnesh42h](https://github.com/purnesh42h)
- [srini100](https://github.com/srini100) - [srini100](https://github.com/srini100)
- [yongni](https://github.com/yongni) - [yongni](https://github.com/yongni)
- [zasweq](https://github.com/zasweq)

View File

@@ -75,8 +75,6 @@ func unregisterForTesting(name string) {
func init() { func init() {
internal.BalancerUnregister = unregisterForTesting internal.BalancerUnregister = unregisterForTesting
internal.ConnectedAddress = connectedAddress
internal.SetConnectedAddress = setConnectedAddress
} }
// Get returns the resolver builder registered with the given name. // Get returns the resolver builder registered with the given name.

View File

@@ -37,6 +37,8 @@ import (
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
) )
var randIntN = rand.IntN
// ChildState is the balancer state of a child along with the endpoint which // ChildState is the balancer state of a child along with the endpoint which
// identifies the child balancer. // identifies the child balancer.
type ChildState struct { type ChildState struct {
@@ -112,6 +114,21 @@ type endpointSharding struct {
mu sync.Mutex mu sync.Mutex
} }
// rotateEndpoints returns a slice of all the input endpoints rotated a random
// amount.
func rotateEndpoints(es []resolver.Endpoint) []resolver.Endpoint {
les := len(es)
if les == 0 {
return es
}
r := randIntN(les)
// Make a copy to avoid mutating data beyond the end of es.
ret := make([]resolver.Endpoint, les)
copy(ret, es[r:])
copy(ret[les-r:], es[:r])
return ret
}
// UpdateClientConnState creates a child for new endpoints and deletes children // UpdateClientConnState creates a child for new endpoints and deletes children
// for endpoints that are no longer present. It also updates all the children, // for endpoints that are no longer present. It also updates all the children,
// and sends a single synchronous update of the childrens' aggregated state at // and sends a single synchronous update of the childrens' aggregated state at
@@ -133,7 +150,7 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState
newChildren := resolver.NewEndpointMap[*balancerWrapper]() newChildren := resolver.NewEndpointMap[*balancerWrapper]()
// Update/Create new children. // Update/Create new children.
for _, endpoint := range state.ResolverState.Endpoints { for _, endpoint := range rotateEndpoints(state.ResolverState.Endpoints) {
if _, ok := newChildren.Get(endpoint); ok { if _, ok := newChildren.Get(endpoint); ok {
// Endpoint child was already created, continue to avoid duplicate // Endpoint child was already created, continue to avoid duplicate
// update. // update.
@@ -279,7 +296,7 @@ func (es *endpointSharding) updateState() {
p := &pickerWithChildStates{ p := &pickerWithChildStates{
pickers: pickers, pickers: pickers,
childStates: childStates, childStates: childStates,
next: uint32(rand.IntN(len(pickers))), next: uint32(randIntN(len(pickers))),
} }
es.cc.UpdateState(balancer.State{ es.cc.UpdateState(balancer.State{
ConnectivityState: aggState, ConnectivityState: aggState,

View File

@@ -26,6 +26,8 @@ import (
var ( var (
// RandShuffle pseudo-randomizes the order of addresses. // RandShuffle pseudo-randomizes the order of addresses.
RandShuffle = rand.Shuffle RandShuffle = rand.Shuffle
// RandFloat64 returns, as a float64, a pseudo-random number in [0.0,1.0).
RandFloat64 = rand.Float64
// TimeAfterFunc allows mocking the timer for testing connection delay // TimeAfterFunc allows mocking the timer for testing connection delay
// related functionality. // related functionality.
TimeAfterFunc = func(d time.Duration, f func()) func() { TimeAfterFunc = func(d time.Duration, f func()) func() {

File diff suppressed because it is too large Load Diff

View File

@@ -1,906 +0,0 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package pickfirstleaf contains the pick_first load balancing policy which
// will be the universal leaf policy after dualstack changes are implemented.
//
// # Experimental
//
// Notice: This package is EXPERIMENTAL and may be changed or removed in a
// later release.
package pickfirstleaf
import (
"encoding/json"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/pickfirst/internal"
"google.golang.org/grpc/connectivity"
expstats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/envconfig"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
func init() {
if envconfig.NewPickFirstEnabled {
// Register as the default pick_first balancer.
Name = "pick_first"
}
balancer.Register(pickfirstBuilder{})
}
// enableHealthListenerKeyType is a unique key type used in resolver
// attributes to indicate whether the health listener usage is enabled.
type enableHealthListenerKeyType struct{}
var (
logger = grpclog.Component("pick-first-leaf-lb")
// Name is the name of the pick_first_leaf balancer.
// It is changed to "pick_first" in init() if this balancer is to be
// registered as the default pickfirst.
Name = "pick_first_leaf"
disconnectionsMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.lb.pick_first.disconnections",
Description: "EXPERIMENTAL. Number of times the selected subchannel becomes disconnected.",
Unit: "disconnection",
Labels: []string{"grpc.target"},
Default: false,
})
connectionAttemptsSucceededMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.lb.pick_first.connection_attempts_succeeded",
Description: "EXPERIMENTAL. Number of successful connection attempts.",
Unit: "attempt",
Labels: []string{"grpc.target"},
Default: false,
})
connectionAttemptsFailedMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.lb.pick_first.connection_attempts_failed",
Description: "EXPERIMENTAL. Number of failed connection attempts.",
Unit: "attempt",
Labels: []string{"grpc.target"},
Default: false,
})
)
const (
// TODO: change to pick-first when this becomes the default pick_first policy.
logPrefix = "[pick-first-leaf-lb %p] "
// connectionDelayInterval is the time to wait for during the happy eyeballs
// pass before starting the next connection attempt.
connectionDelayInterval = 250 * time.Millisecond
)
type ipAddrFamily int
const (
// ipAddrFamilyUnknown represents strings that can't be parsed as an IP
// address.
ipAddrFamilyUnknown ipAddrFamily = iota
ipAddrFamilyV4
ipAddrFamilyV6
)
type pickfirstBuilder struct{}
func (pickfirstBuilder) Build(cc balancer.ClientConn, bo balancer.BuildOptions) balancer.Balancer {
b := &pickfirstBalancer{
cc: cc,
target: bo.Target.String(),
metricsRecorder: cc.MetricsRecorder(),
subConns: resolver.NewAddressMapV2[*scData](),
state: connectivity.Connecting,
cancelConnectionTimer: func() {},
}
b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(logPrefix, b))
return b
}
func (b pickfirstBuilder) Name() string {
return Name
}
func (pickfirstBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
var cfg pfConfig
if err := json.Unmarshal(js, &cfg); err != nil {
return nil, fmt.Errorf("pickfirst: unable to unmarshal LB policy config: %s, error: %v", string(js), err)
}
return cfg, nil
}
// EnableHealthListener updates the state to configure pickfirst for using a
// generic health listener.
func EnableHealthListener(state resolver.State) resolver.State {
state.Attributes = state.Attributes.WithValue(enableHealthListenerKeyType{}, true)
return state
}
type pfConfig struct {
serviceconfig.LoadBalancingConfig `json:"-"`
// If set to true, instructs the LB policy to shuffle the order of the list
// of endpoints received from the name resolver before attempting to
// connect to them.
ShuffleAddressList bool `json:"shuffleAddressList"`
}
// scData keeps track of the current state of the subConn.
// It is not safe for concurrent access.
type scData struct {
// The following fields are initialized at build time and read-only after
// that.
subConn balancer.SubConn
addr resolver.Address
rawConnectivityState connectivity.State
// The effective connectivity state based on raw connectivity, health state
// and after following sticky TransientFailure behaviour defined in A62.
effectiveState connectivity.State
lastErr error
connectionFailedInFirstPass bool
}
func (b *pickfirstBalancer) newSCData(addr resolver.Address) (*scData, error) {
sd := &scData{
rawConnectivityState: connectivity.Idle,
effectiveState: connectivity.Idle,
addr: addr,
}
sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{
StateListener: func(state balancer.SubConnState) {
b.updateSubConnState(sd, state)
},
})
if err != nil {
return nil, err
}
sd.subConn = sc
return sd, nil
}
type pickfirstBalancer struct {
// The following fields are initialized at build time and read-only after
// that and therefore do not need to be guarded by a mutex.
logger *internalgrpclog.PrefixLogger
cc balancer.ClientConn
target string
metricsRecorder expstats.MetricsRecorder // guaranteed to be non nil
// The mutex is used to ensure synchronization of updates triggered
// from the idle picker and the already serialized resolver,
// SubConn state updates.
mu sync.Mutex
// State reported to the channel based on SubConn states and resolver
// updates.
state connectivity.State
// scData for active subonns mapped by address.
subConns *resolver.AddressMapV2[*scData]
addressList addressList
firstPass bool
numTF int
cancelConnectionTimer func()
healthCheckingEnabled bool
}
// ResolverError is called by the ClientConn when the name resolver produces
// an error or when pickfirst determined the resolver update to be invalid.
func (b *pickfirstBalancer) ResolverError(err error) {
b.mu.Lock()
defer b.mu.Unlock()
b.resolverErrorLocked(err)
}
func (b *pickfirstBalancer) resolverErrorLocked(err error) {
if b.logger.V(2) {
b.logger.Infof("Received error from the name resolver: %v", err)
}
// The picker will not change since the balancer does not currently
// report an error. If the balancer hasn't received a single good resolver
// update yet, transition to TRANSIENT_FAILURE.
if b.state != connectivity.TransientFailure && b.addressList.size() > 0 {
if b.logger.V(2) {
b.logger.Infof("Ignoring resolver error because balancer is using a previous good update.")
}
return
}
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: fmt.Errorf("name resolver error: %v", err)},
})
}
// UpdateClientConnState handles a new resolver state and/or LB config from
// the channel. It flattens endpoints into a single ordered address list
// (optionally shuffled per gRFC A62), de-duplicates and interleaves the
// addresses per RFC 8305, reconciles existing SubConns against the new list,
// and starts a new connection attempt when required.
//
// Returns balancer.ErrBadResolverState when the resolver produced no
// addresses or an illegal BalancerConfig type.
func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.cancelConnectionTimer()
	if len(state.ResolverState.Addresses) == 0 && len(state.ResolverState.Endpoints) == 0 {
		// Cleanup state pertaining to the previous resolver state.
		// Treat an empty address list like an error by calling b.ResolverError.
		b.closeSubConnsLocked()
		b.addressList.updateAddrs(nil)
		b.resolverErrorLocked(errors.New("produced zero addresses"))
		return balancer.ErrBadResolverState
	}
	b.healthCheckingEnabled = state.ResolverState.Attributes.Value(enableHealthListenerKeyType{}) != nil
	cfg, ok := state.BalancerConfig.(pfConfig)
	if state.BalancerConfig != nil && !ok {
		return fmt.Errorf("pickfirst: received illegal BalancerConfig (type %T): %v: %w", state.BalancerConfig, state.BalancerConfig, balancer.ErrBadResolverState)
	}
	if b.logger.V(2) {
		b.logger.Infof("Received new config %s, resolver state %s", pretty.ToJSON(cfg), pretty.ToJSON(state.ResolverState))
	}
	var newAddrs []resolver.Address
	if endpoints := state.ResolverState.Endpoints; len(endpoints) != 0 {
		// Perform the optional shuffling described in gRFC A62. The shuffling
		// will change the order of endpoints but not touch the order of the
		// addresses within each endpoint. - A61
		if cfg.ShuffleAddressList {
			endpoints = append([]resolver.Endpoint{}, endpoints...)
			internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] })
		}
		// "Flatten the list by concatenating the ordered list of addresses for
		// each of the endpoints, in order." - A61
		for _, endpoint := range endpoints {
			newAddrs = append(newAddrs, endpoint.Addresses...)
		}
	} else {
		// Endpoints not set, process addresses until we migrate resolver
		// emissions fully to Endpoints. The top channel does wrap emitted
		// addresses with endpoints, however some balancers such as weighted
		// target do not forward the corresponding correct endpoints down/split
		// endpoints properly. Once all balancers correctly forward endpoints
		// down, can delete this else conditional.
		newAddrs = state.ResolverState.Addresses
		if cfg.ShuffleAddressList {
			newAddrs = append([]resolver.Address{}, newAddrs...)
			// Shuffle the copied address list itself. (Shuffling `endpoints`
			// here would be a no-op since it is empty on this branch.)
			internal.RandShuffle(len(newAddrs), func(i, j int) { newAddrs[i], newAddrs[j] = newAddrs[j], newAddrs[i] })
		}
	}
	// If an address appears in multiple endpoints or in the same endpoint
	// multiple times, we keep it only once. We will create only one SubConn
	// for the address because an AddressMap is used to store SubConns.
	// Not de-duplicating would result in attempting to connect to the same
	// SubConn multiple times in the same pass. We don't want this.
	newAddrs = deDupAddresses(newAddrs)
	newAddrs = interleaveAddresses(newAddrs)
	prevAddr := b.addressList.currentAddress()
	prevSCData, found := b.subConns.Get(prevAddr)
	prevAddrsCount := b.addressList.size()
	isPrevRawConnectivityStateReady := found && prevSCData.rawConnectivityState == connectivity.Ready
	b.addressList.updateAddrs(newAddrs)
	// If the previous ready SubConn exists in new address list,
	// keep this connection and don't create new SubConns.
	if isPrevRawConnectivityStateReady && b.addressList.seekTo(prevAddr) {
		return nil
	}
	b.reconcileSubConnsLocked(newAddrs)
	// If it's the first resolver update or the balancer was already READY
	// (but the new address list does not contain the ready SubConn) or
	// CONNECTING, enter CONNECTING.
	// We may be in TRANSIENT_FAILURE due to a previous empty address list,
	// we should still enter CONNECTING because the sticky TF behaviour
	// mentioned in A62 applies only when the TRANSIENT_FAILURE is reported
	// due to connectivity failures.
	if isPrevRawConnectivityStateReady || b.state == connectivity.Connecting || prevAddrsCount == 0 {
		// Start connection attempt at first address.
		b.forceUpdateConcludedStateLocked(balancer.State{
			ConnectivityState: connectivity.Connecting,
			Picker:            &picker{err: balancer.ErrNoSubConnAvailable},
		})
		b.startFirstPassLocked()
	} else if b.state == connectivity.TransientFailure {
		// If we're in TRANSIENT_FAILURE, we stay in TRANSIENT_FAILURE until
		// we're READY. See A62.
		b.startFirstPassLocked()
	}
	return nil
}
// UpdateSubConnState implements the legacy balancer.Balancer callback. It is
// never expected to fire because every SubConn this balancer creates is given
// a StateListener at construction time; reaching this method indicates a bug,
// so it is logged at error level.
func (b *pickfirstBalancer) UpdateSubConnState(subConn balancer.SubConn, state balancer.SubConnState) {
	b.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", subConn, state)
}
// Close shuts the balancer down: all SubConns are shut down, any pending
// happy-eyeballs connection timer is cancelled, and the balancer's state is
// marked SHUTDOWN so no further work is scheduled.
func (b *pickfirstBalancer) Close() {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.closeSubConnsLocked()
	b.cancelConnectionTimer()
	b.state = connectivity.Shutdown
}
// ExitIdle moves the balancer out of idle state. It can be called
// concurrently by the idlePicker and clientConn, so all accesses happen under
// the balancer's mutex. Calls while not IDLE are no-ops.
func (b *pickfirstBalancer) ExitIdle() {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.state != connectivity.Idle {
		return
	}
	b.startFirstPassLocked()
}
// startFirstPassLocked begins a fresh happy-eyeballs pass over the address
// list: the first-pass flag is raised, the post-first-pass failure counter is
// zeroed, per-SubConn failure records are cleared, and a connection attempt
// is requested. Caller must hold b.mu.
func (b *pickfirstBalancer) startFirstPassLocked() {
	b.firstPass = true
	b.numTF = 0
	// Forget any failures recorded during an earlier pass so that each
	// SubConn gets a clean attempt in this one.
	for _, data := range b.subConns.Values() {
		data.connectionFailedInFirstPass = false
	}
	b.requestConnectionLocked()
}
// closeSubConnsLocked shuts down every tracked SubConn and replaces the
// SubConn map with an empty one. Caller must hold b.mu.
func (b *pickfirstBalancer) closeSubConnsLocked() {
	for _, data := range b.subConns.Values() {
		data.subConn.Shutdown()
	}
	b.subConns = resolver.NewAddressMapV2[*scData]()
}
// deDupAddresses ensures that each address appears only once in the returned
// slice, preserving first-occurrence order.
//
// Bug fix: the original never inserted addresses into seenAddrs, so the
// Get lookup always missed and no de-duplication ever happened. Each address
// is now recorded in the map when first seen.
func deDupAddresses(addrs []resolver.Address) []resolver.Address {
	seenAddrs := resolver.NewAddressMapV2[*scData]()
	retAddrs := []resolver.Address{}
	for _, addr := range addrs {
		if _, ok := seenAddrs.Get(addr); ok {
			continue
		}
		// Record the address so later duplicates are skipped; the map value
		// is unused, only key membership matters.
		seenAddrs.Set(addr, nil)
		retAddrs = append(retAddrs, addr)
	}
	return retAddrs
}
// interleaveAddresses interleaves addresses of both families (IPv4 and IPv6)
// as per RFC-8305 section 4.
// Whichever address family is first in the list is followed by an address of
// the other address family; that is, if the first address in the list is IPv6,
// then the first IPv4 address should be moved up in the list to be second in
// the list. It doesn't support configuring "First Address Family Count", i.e.
// there will always be a single member of the first address family at the
// beginning of the interleaved list.
// Addresses that are neither IPv4 nor IPv6 are treated as part of a third
// "unknown" family for interleaving.
// See: https://datatracker.ietf.org/doc/html/rfc8305#autoid-6
func interleaveAddresses(addrs []resolver.Address) []resolver.Address {
	// Bucket the addresses by family, remembering the order in which each
	// family was first encountered.
	byFamily := map[ipAddrFamily][]resolver.Address{}
	familyOrder := []ipAddrFamily{}
	for _, addr := range addrs {
		fam := addressFamily(addr.Addr)
		if _, seen := byFamily[fam]; !seen {
			familyOrder = append(familyOrder, fam)
		}
		byFamily[fam] = append(byFamily[fam], addr)
	}
	// Round-robin over the families, popping one address per visit. Families
	// that run out early are simply skipped until all addresses are placed.
	result := make([]resolver.Address, 0, len(addrs))
	for famIdx := 0; len(result) < len(addrs); famIdx = (famIdx + 1) % len(familyOrder) {
		fam := familyOrder[famIdx]
		if pending := byFamily[fam]; len(pending) > 0 {
			result = append(result, pending[0])
			byFamily[fam] = pending[1:]
		}
	}
	return result
}
// addressFamily returns the ipAddrFamily after parsing the address string.
// If the address isn't of the format "ip-address:port", it returns
// ipAddrFamilyUnknown. The address may be valid even if it's not an IP when
// using a resolver like passthrough where the address may be a hostname in
// some format that the dialer can resolve.
func addressFamily(address string) ipAddrFamily {
	// Strip the port; anything that doesn't split as host:port is not
	// classifiable here.
	host, _, err := net.SplitHostPort(address)
	if err != nil {
		return ipAddrFamilyUnknown
	}
	parsed, err := netip.ParseAddr(host)
	if err != nil {
		return ipAddrFamilyUnknown
	}
	// IPv4-mapped IPv6 addresses are treated as IPv4 for interleaving.
	if parsed.Is4() || parsed.Is4In6() {
		return ipAddrFamilyV4
	}
	if parsed.Is6() {
		return ipAddrFamilyV6
	}
	return ipAddrFamilyUnknown
}
// reconcileSubConnsLocked updates the active subchannels based on a new
// address list from the resolver: any SubConn whose address is absent from
// newAddrs is shut down and dropped from the SubConn map, so the map exactly
// reflects the addresses currently reported by the name resolver. Caller
// must hold b.mu.
func (b *pickfirstBalancer) reconcileSubConnsLocked(newAddrs []resolver.Address) {
	// Build a membership set of the addresses we want to keep.
	keep := resolver.NewAddressMapV2[bool]()
	for _, addr := range newAddrs {
		keep.Set(addr, true)
	}
	for _, existing := range b.subConns.Keys() {
		if _, retained := keep.Get(existing); retained {
			continue
		}
		// Address no longer present: tear down its subchannel and forget it.
		sd, _ := b.subConns.Get(existing)
		sd.subConn.Shutdown()
		b.subConns.Delete(existing)
	}
}
// shutdownRemainingLocked shuts down every SubConn except the selected one.
// Called when a SubConn becomes ready, at which point all other subchannels
// must be discarded; the SubConn map is rebuilt to contain only the selected
// entry. Caller must hold b.mu.
func (b *pickfirstBalancer) shutdownRemainingLocked(selected *scData) {
	b.cancelConnectionTimer()
	for _, candidate := range b.subConns.Values() {
		if candidate.subConn == selected.subConn {
			continue
		}
		candidate.subConn.Shutdown()
	}
	b.subConns = resolver.NewAddressMapV2[*scData]()
	b.subConns.Set(selected.addr, selected)
}
// requestConnectionLocked starts connecting on the subchannel corresponding to
// the current address. If no subchannel exists, one is created. If the current
// subchannel is in TransientFailure, a connection to the next address is
// attempted until a subchannel is found.
//
// Caller must hold b.mu. If every remaining address is in TRANSIENT_FAILURE,
// the first pass is ended (if possible) with the last observed error.
func (b *pickfirstBalancer) requestConnectionLocked() {
	// Nothing to do once the address list cursor has run off the end.
	if !b.addressList.isValid() {
		return
	}
	var lastErr error
	// Walk forward through the address list until a connection attempt is
	// started or the list is exhausted; `valid` is false once increment()
	// moves past the last address.
	for valid := true; valid; valid = b.addressList.increment() {
		curAddr := b.addressList.currentAddress()
		sd, ok := b.subConns.Get(curAddr)
		if !ok {
			var err error
			// We want to assign the new scData to sd from the outer scope,
			// hence we can't use := below.
			sd, err = b.newSCData(curAddr)
			if err != nil {
				// This should never happen, unless the clientConn is being shut
				// down.
				if b.logger.V(2) {
					b.logger.Infof("Failed to create a subConn for address %v: %v", curAddr.String(), err)
				}
				// Do nothing, the LB policy will be closed soon.
				return
			}
			b.subConns.Set(curAddr, sd)
		}
		switch sd.rawConnectivityState {
		case connectivity.Idle:
			// Kick off a connection attempt and arm the happy-eyeballs timer
			// so the next address is tried if this one is slow.
			sd.subConn.Connect()
			b.scheduleNextConnectionLocked()
			return
		case connectivity.TransientFailure:
			// The SubConn is being re-used and failed during a previous pass
			// over the addressList. It has not completed backoff yet.
			// Mark it as having failed and try the next address.
			sd.connectionFailedInFirstPass = true
			lastErr = sd.lastErr
			continue
		case connectivity.Connecting:
			// Wait for the connection attempt to complete or the timer to fire
			// before attempting the next address.
			b.scheduleNextConnectionLocked()
			return
		default:
			// READY/SHUTDOWN entries are not expected here; bail out rather
			// than risk corrupting the pass.
			b.logger.Errorf("SubConn with unexpected state %v present in SubConns map.", sd.rawConnectivityState)
			return
		}
	}
	// All the remaining addresses in the list are in TRANSIENT_FAILURE, end the
	// first pass if possible.
	b.endFirstPassIfPossibleLocked(lastErr)
}
// scheduleNextConnectionLocked arms the happy-eyeballs timer: if the current
// connection attempt has not resolved within connectionDelayInterval, the
// balancer advances to the next address and starts another attempt in
// parallel. Any previously armed timer is cancelled first. No timer is armed
// when the current address is the last one in the list. Caller must hold b.mu.
func (b *pickfirstBalancer) scheduleNextConnectionLocked() {
	b.cancelConnectionTimer()
	if !b.addressList.hasNext() {
		return
	}
	curAddr := b.addressList.currentAddress()
	cancelled := false // Access to this is protected by the balancer's mutex.
	closeFn := internal.TimeAfterFunc(connectionDelayInterval, func() {
		b.mu.Lock()
		defer b.mu.Unlock()
		// If the scheduled task is cancelled while acquiring the mutex, return.
		if cancelled {
			return
		}
		if b.logger.V(2) {
			b.logger.Infof("Happy Eyeballs timer expired while waiting for connection to %q.", curAddr.Addr)
		}
		if b.addressList.increment() {
			b.requestConnectionLocked()
		}
	})
	// Access to the cancellation callback held by the balancer is guarded by
	// the balancer's mutex, so it's safe to set the boolean from the callback.
	// sync.OnceFunc makes the stored canceller idempotent.
	b.cancelConnectionTimer = sync.OnceFunc(func() {
		cancelled = true
		closeFn()
	})
}
// updateSubConnState is the StateListener callback for raw (transport-level)
// connectivity state changes on a SubConn. It drives the pick_first state
// machine: promoting a newly READY SubConn (directly, or via a health
// listener when health checking is enabled), handling loss of the selected
// connection by returning to IDLE, and advancing the happy-eyeballs first
// pass on CONNECTING/TRANSIENT_FAILURE transitions. Stale callbacks from
// SubConns no longer in the active map are ignored.
func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.SubConnState) {
	b.mu.Lock()
	defer b.mu.Unlock()
	oldState := sd.rawConnectivityState
	sd.rawConnectivityState = newState.ConnectivityState
	// Previously relevant SubConns can still callback with state updates.
	// To prevent pickers from returning these obsolete SubConns, this logic
	// is included to check if the current list of active SubConns includes this
	// SubConn.
	if !b.isActiveSCData(sd) {
		return
	}
	if newState.ConnectivityState == connectivity.Shutdown {
		sd.effectiveState = connectivity.Shutdown
		return
	}
	// Record a connection attempt when exiting CONNECTING.
	if newState.ConnectivityState == connectivity.TransientFailure {
		sd.connectionFailedInFirstPass = true
		connectionAttemptsFailedMetric.Record(b.metricsRecorder, 1, b.target)
	}
	if newState.ConnectivityState == connectivity.Ready {
		connectionAttemptsSucceededMetric.Record(b.metricsRecorder, 1, b.target)
		// A SubConn won: shut down all other subchannels.
		b.shutdownRemainingLocked(sd)
		if !b.addressList.seekTo(sd.addr) {
			// This should not fail as we should have only one SubConn after
			// entering READY. The SubConn should be present in the addressList.
			b.logger.Errorf("Address %q not found address list in %v", sd.addr, b.addressList.addresses)
			return
		}
		if !b.healthCheckingEnabled {
			if b.logger.V(2) {
				b.logger.Infof("SubConn %p reported connectivity state READY and the health listener is disabled. Transitioning SubConn to READY.", sd.subConn)
			}
			sd.effectiveState = connectivity.Ready
			b.updateBalancerState(balancer.State{
				ConnectivityState: connectivity.Ready,
				Picker:            &picker{result: balancer.PickResult{SubConn: sd.subConn}},
			})
			return
		}
		if b.logger.V(2) {
			b.logger.Infof("SubConn %p reported connectivity state READY. Registering health listener.", sd.subConn)
		}
		// Send a CONNECTING update to take the SubConn out of sticky-TF if
		// required.
		sd.effectiveState = connectivity.Connecting
		b.updateBalancerState(balancer.State{
			ConnectivityState: connectivity.Connecting,
			Picker:            &picker{err: balancer.ErrNoSubConnAvailable},
		})
		// READY is reported to the channel only after the health listener
		// (updateSubConnHealthState) confirms health.
		sd.subConn.RegisterHealthListener(func(scs balancer.SubConnState) {
			b.updateSubConnHealthState(sd, scs)
		})
		return
	}
	// If the LB policy is READY, and it receives a subchannel state change,
	// it means that the READY subchannel has failed.
	// A SubConn can also transition from CONNECTING directly to IDLE when
	// a transport is successfully created, but the connection fails
	// before the SubConn can send the notification for READY. We treat
	// this as a successful connection and transition to IDLE.
	// TODO: https://github.com/grpc/grpc-go/issues/7862 - Remove the second
	// part of the if condition below once the issue is fixed.
	if oldState == connectivity.Ready || (oldState == connectivity.Connecting && newState.ConnectivityState == connectivity.Idle) {
		// Once a transport fails, the balancer enters IDLE and starts from
		// the first address when the picker is used.
		b.shutdownRemainingLocked(sd)
		sd.effectiveState = newState.ConnectivityState
		// READY SubConn interspliced in between CONNECTING and IDLE, need to
		// account for that.
		if oldState == connectivity.Connecting {
			// A known issue (https://github.com/grpc/grpc-go/issues/7862)
			// causes a race that prevents the READY state change notification.
			// This works around it.
			connectionAttemptsSucceededMetric.Record(b.metricsRecorder, 1, b.target)
		}
		disconnectionsMetric.Record(b.metricsRecorder, 1, b.target)
		b.addressList.reset()
		b.updateBalancerState(balancer.State{
			ConnectivityState: connectivity.Idle,
			Picker:            &idlePicker{exitIdle: sync.OnceFunc(b.ExitIdle)},
		})
		return
	}
	if b.firstPass {
		switch newState.ConnectivityState {
		case connectivity.Connecting:
			// The effective state can be in either IDLE, CONNECTING or
			// TRANSIENT_FAILURE. If it's TRANSIENT_FAILURE, stay in
			// TRANSIENT_FAILURE until it's READY. See A62.
			if sd.effectiveState != connectivity.TransientFailure {
				sd.effectiveState = connectivity.Connecting
				b.updateBalancerState(balancer.State{
					ConnectivityState: connectivity.Connecting,
					Picker:            &picker{err: balancer.ErrNoSubConnAvailable},
				})
			}
		case connectivity.TransientFailure:
			sd.lastErr = newState.ConnectionError
			sd.effectiveState = connectivity.TransientFailure
			// Since we're re-using common SubConns while handling resolver
			// updates, we could receive an out of turn TRANSIENT_FAILURE from
			// a pass over the previous address list. Happy Eyeballs will also
			// cause out of order updates to arrive.
			if curAddr := b.addressList.currentAddress(); equalAddressIgnoringBalAttributes(&curAddr, &sd.addr) {
				b.cancelConnectionTimer()
				if b.addressList.increment() {
					b.requestConnectionLocked()
					return
				}
			}
			// End the first pass if we've seen a TRANSIENT_FAILURE from all
			// SubConns once.
			b.endFirstPassIfPossibleLocked(newState.ConnectionError)
		}
		return
	}
	// We have finished the first pass, keep re-connecting failing SubConns.
	switch newState.ConnectivityState {
	case connectivity.TransientFailure:
		b.numTF = (b.numTF + 1) % b.subConns.Len()
		sd.lastErr = newState.ConnectionError
		// Only surface a TF picker update once per full cycle of failures
		// across all SubConns, to refresh the reported connection error.
		if b.numTF%b.subConns.Len() == 0 {
			b.updateBalancerState(balancer.State{
				ConnectivityState: connectivity.TransientFailure,
				Picker:            &picker{err: newState.ConnectionError},
			})
		}
		// We don't need to request re-resolution since the SubConn already
		// does that before reporting TRANSIENT_FAILURE.
		// TODO: #7534 - Move re-resolution requests from SubConn into
		// pick_first.
	case connectivity.Idle:
		sd.subConn.Connect()
	}
}
// endFirstPassIfPossibleLocked ends the first happy-eyeballs pass if every
// address has been tried and each SubConn has reported a failure. On ending
// the pass it reports TRANSIENT_FAILURE with lastErr and re-kicks all IDLE
// SubConns. Caller must hold b.mu.
func (b *pickfirstBalancer) endFirstPassIfPossibleLocked(lastErr error) {
	// An optimization to avoid iterating over the entire SubConn map: a
	// still-valid address cursor means the pass hasn't covered the list yet.
	if b.addressList.isValid() {
		return
	}
	// Connect() has been called on all the SubConns. The pass can only end
	// once every SubConn has recorded a failure.
	for _, data := range b.subConns.Values() {
		if !data.connectionFailedInFirstPass {
			return
		}
	}
	b.firstPass = false
	b.updateBalancerState(balancer.State{
		ConnectivityState: connectivity.TransientFailure,
		Picker:            &picker{err: lastErr},
	})
	// Start re-connecting all the SubConns that are already in IDLE.
	for _, data := range b.subConns.Values() {
		if data.rawConnectivityState == connectivity.Idle {
			data.subConn.Connect()
		}
	}
}
// isActiveSCData reports whether sd is the entry currently tracked in the
// SubConn map for its address, i.e. whether callbacks from it are still
// relevant.
func (b *pickfirstBalancer) isActiveSCData(sd *scData) bool {
	if current, ok := b.subConns.Get(sd.addr); ok {
		return current == sd
	}
	return false
}
// updateSubConnHealthState is the health-listener callback for the selected
// SubConn. It mirrors the reported health state into the SubConn's effective
// state and publishes a matching picker: READY on healthy, TRANSIENT_FAILURE
// (wrapping the health-check error) on unhealthy, CONNECTING while pending.
// Stale callbacks from SubConns no longer in the active map are ignored.
func (b *pickfirstBalancer) updateSubConnHealthState(sd *scData, state balancer.SubConnState) {
	b.mu.Lock()
	defer b.mu.Unlock()
	// Previously relevant SubConns can still callback with state updates.
	// To prevent pickers from returning these obsolete SubConns, this logic
	// is included to check if the current list of active SubConns includes
	// this SubConn.
	if !b.isActiveSCData(sd) {
		return
	}
	sd.effectiveState = state.ConnectivityState
	switch state.ConnectivityState {
	case connectivity.Ready:
		b.updateBalancerState(balancer.State{
			ConnectivityState: connectivity.Ready,
			Picker:            &picker{result: balancer.PickResult{SubConn: sd.subConn}},
		})
	case connectivity.TransientFailure:
		b.updateBalancerState(balancer.State{
			ConnectivityState: connectivity.TransientFailure,
			Picker:            &picker{err: fmt.Errorf("pickfirst: health check failure: %v", state.ConnectionError)},
		})
	case connectivity.Connecting:
		b.updateBalancerState(balancer.State{
			ConnectivityState: connectivity.Connecting,
			Picker:            &picker{err: balancer.ErrNoSubConnAvailable},
		})
	default:
		// Fix: the format string has two verbs (%p, %v) but the original
		// passed only `state`, producing garbled log output; supply both args.
		b.logger.Errorf("Got unexpected health update for SubConn %p: %v", sd.subConn, state)
	}
}
// updateBalancerState stores the state reported to the channel and calls
// ClientConn.UpdateState(). As an optimization, duplicate updates are
// suppressed — except in TRANSIENT_FAILURE, where repeated updates are
// forwarded so the picker can carry the latest connectivity error.
func (b *pickfirstBalancer) updateBalancerState(newState balancer.State) {
	// Forward when the state actually changes, or whenever we are in TF
	// (to refresh the error in the picker).
	if newState.ConnectivityState != b.state || b.state == connectivity.TransientFailure {
		b.forceUpdateConcludedStateLocked(newState)
	}
}
// forceUpdateConcludedStateLocked records newState as the balancer's
// concluded state and pushes it to the ClientConn unconditionally. It exists
// separately from updateBalancerState because the channel does not assume LB
// policies start in CONNECTING and relies on the policy to send an initial
// CONNECTING update, which must not be de-duplicated away.
func (b *pickfirstBalancer) forceUpdateConcludedStateLocked(newState balancer.State) {
	b.state = newState.ConnectivityState
	b.cc.UpdateState(newState)
}
// picker is a fixed-result picker: every Pick returns the same stored
// PickResult and error, set at construction time.
type picker struct {
	result balancer.PickResult // returned for every pick
	err    error               // returned for every pick; nil when result is usable
}

// Pick returns the stored result and error for every RPC.
func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
	return p.result, p.err
}
// idlePicker is used when the SubConn is IDLE and kicks the SubConn into
// CONNECTING when Pick is called.
type idlePicker struct {
	exitIdle func() // invoked on every Pick to move the balancer out of IDLE
}

// Pick triggers exitIdle and tells the channel no SubConn is available yet,
// so the RPC is queued until a new connectivity update arrives.
func (i *idlePicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
	i.exitIdle()
	return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
}
// addressList manages sequentially iterating over addresses present in a list
// of endpoints. It provides a 1 dimensional view of the addresses present in
// the endpoints.
// This type is not safe for concurrent access.
type addressList struct {
	addresses []resolver.Address // flattened, ordered address list
	idx       int                // cursor; equals len(addresses) once exhausted
}
// isValid reports whether the cursor still points at an address (i.e. the
// list has not been exhausted).
func (al *addressList) isValid() bool {
	return len(al.addresses) > al.idx
}
// size returns the total number of addresses in the list.
func (al *addressList) size() int {
	return len(al.addresses)
}
// increment moves to the next index in the address list.
// This method returns false if it went off the list, true otherwise. Calling
// it on an already-exhausted list is a no-op returning false.
func (al *addressList) increment() bool {
	if al.idx >= len(al.addresses) {
		return false
	}
	al.idx++
	return al.idx < len(al.addresses)
}
// currentAddress returns the address the cursor points at, or the zero
// resolver.Address when the list has been exhausted.
func (al *addressList) currentAddress() resolver.Address {
	if al.idx >= len(al.addresses) {
		return resolver.Address{}
	}
	return al.addresses[al.idx]
}
// reset rewinds the cursor to the first address.
func (al *addressList) reset() {
	al.idx = 0
}
// updateAddrs replaces the backing address list and rewinds the cursor to
// the start.
func (al *addressList) updateAddrs(addrs []resolver.Address) {
	al.addresses = addrs
	al.idx = 0
}
// seekTo positions the cursor on the first address equal to needle
// (ignoring balancer attributes) and returns true. It returns false — and
// leaves the cursor untouched — when no match exists.
func (al *addressList) seekTo(needle resolver.Address) bool {
	for i := range al.addresses {
		if equalAddressIgnoringBalAttributes(&al.addresses[i], &needle) {
			al.idx = i
			return true
		}
	}
	return false
}
// hasNext reports whether there is at least one more address after the
// current cursor position. It returns false when the cursor is on the last
// address or has already moved past the end of the list.
func (al *addressList) hasNext() bool {
	// idx+1 < len implies idx < len, so a separate validity check is
	// unnecessary (idx is never negative).
	return al.idx+1 < len(al.addresses)
}
// equalAddressIgnoringBalAttributes returns true if a and b are considered
// equal. Unlike resolver.Address.Equal, which compares all fields, only the
// fields meaningful to the SubConn are compared: Addr, ServerName and
// Attributes (balancer attributes are ignored).
func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool {
	if a.Addr != b.Addr {
		return false
	}
	if a.ServerName != b.ServerName {
		return false
	}
	return a.Attributes.Equal(b.Attributes)
}

View File

@@ -26,7 +26,7 @@ import (
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/endpointsharding" "google.golang.org/grpc/balancer/endpointsharding"
"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf" "google.golang.org/grpc/balancer/pickfirst"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
internalgrpclog "google.golang.org/grpc/internal/grpclog" internalgrpclog "google.golang.org/grpc/internal/grpclog"
) )
@@ -47,7 +47,7 @@ func (bb builder) Name() string {
} }
func (bb builder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { func (bb builder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
childBuilder := balancer.Get(pickfirstleaf.Name).Build childBuilder := balancer.Get(pickfirst.Name).Build
bal := &rrBalancer{ bal := &rrBalancer{
cc: cc, cc: cc,
Balancer: endpointsharding.NewBalancer(cc, opts, childBuilder, endpointsharding.Options{}), Balancer: endpointsharding.NewBalancer(cc, opts, childBuilder, endpointsharding.Options{}),
@@ -67,6 +67,6 @@ func (b *rrBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
return b.Balancer.UpdateClientConnState(balancer.ClientConnState{ return b.Balancer.UpdateClientConnState(balancer.ClientConnState{
// Enable the health listener in pickfirst children for client side health // Enable the health listener in pickfirst children for client side health
// checks and outlier detection, if configured. // checks and outlier detection, if configured.
ResolverState: pickfirstleaf.EnableHealthListener(ccs.ResolverState), ResolverState: pickfirst.EnableHealthListener(ccs.ResolverState),
}) })
} }

View File

@@ -111,20 +111,6 @@ type SubConnState struct {
// ConnectionError is set if the ConnectivityState is TransientFailure, // ConnectionError is set if the ConnectivityState is TransientFailure,
// describing the reason the SubConn failed. Otherwise, it is nil. // describing the reason the SubConn failed. Otherwise, it is nil.
ConnectionError error ConnectionError error
// connectedAddr contains the connected address when ConnectivityState is
// Ready. Otherwise, it is indeterminate.
connectedAddress resolver.Address
}
// connectedAddress returns the connected address for a SubConnState. The
// address is only valid if the state is READY.
func connectedAddress(scs SubConnState) resolver.Address {
return scs.connectedAddress
}
// setConnectedAddress sets the connected address for a SubConnState.
func setConnectedAddress(scs *SubConnState, addr resolver.Address) {
scs.connectedAddress = addr
} }
// A Producer is a type shared among potentially many consumers. It is // A Producer is a type shared among potentially many consumers. It is

View File

@@ -36,7 +36,6 @@ import (
) )
var ( var (
setConnectedAddress = internal.SetConnectedAddress.(func(*balancer.SubConnState, resolver.Address))
// noOpRegisterHealthListenerFn is used when client side health checking is // noOpRegisterHealthListenerFn is used when client side health checking is
// disabled. It sends a single READY update on the registered listener. // disabled. It sends a single READY update on the registered listener.
noOpRegisterHealthListenerFn = func(_ context.Context, listener func(balancer.SubConnState)) func() { noOpRegisterHealthListenerFn = func(_ context.Context, listener func(balancer.SubConnState)) func() {
@@ -305,7 +304,7 @@ func newHealthData(s connectivity.State) *healthData {
// updateState is invoked by grpc to push a subConn state update to the // updateState is invoked by grpc to push a subConn state update to the
// underlying balancer. // underlying balancer.
func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error) { func (acbw *acBalancerWrapper) updateState(s connectivity.State, err error) {
acbw.ccb.serializer.TrySchedule(func(ctx context.Context) { acbw.ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || acbw.ccb.balancer == nil { if ctx.Err() != nil || acbw.ccb.balancer == nil {
return return
@@ -317,9 +316,6 @@ func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolve
// opts.StateListener is set, so this cannot ever be nil. // opts.StateListener is set, so this cannot ever be nil.
// TODO: delete this comment when UpdateSubConnState is removed. // TODO: delete this comment when UpdateSubConnState is removed.
scs := balancer.SubConnState{ConnectivityState: s, ConnectionError: err} scs := balancer.SubConnState{ConnectivityState: s, ConnectionError: err}
if s == connectivity.Ready {
setConnectedAddress(&scs, curAddr)
}
// Invalidate the health listener by updating the healthData. // Invalidate the health listener by updating the healthData.
acbw.healthMu.Lock() acbw.healthMu.Lock()
// A race may occur if a health listener is registered soon after the // A race may occur if a health listener is registered soon after the
@@ -450,13 +446,14 @@ func (acbw *acBalancerWrapper) healthListenerRegFn() func(context.Context, func(
if acbw.ccb.cc.dopts.disableHealthCheck { if acbw.ccb.cc.dopts.disableHealthCheck {
return noOpRegisterHealthListenerFn return noOpRegisterHealthListenerFn
} }
cfg := acbw.ac.cc.healthCheckConfig()
if cfg == nil {
return noOpRegisterHealthListenerFn
}
regHealthLisFn := internal.RegisterClientHealthCheckListener regHealthLisFn := internal.RegisterClientHealthCheckListener
if regHealthLisFn == nil { if regHealthLisFn == nil {
// The health package is not imported. // The health package is not imported.
return noOpRegisterHealthListenerFn channelz.Error(logger, acbw.ac.channelz, "Health check is requested but health package is not imported.")
}
cfg := acbw.ac.cc.healthCheckConfig()
if cfg == nil {
return noOpRegisterHealthListenerFn return noOpRegisterHealthListenerFn
} }
return func(ctx context.Context, listener func(balancer.SubConnState)) func() { return func(ctx context.Context, listener func(balancer.SubConnState)) func() {

View File

@@ -18,7 +18,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.36.6 // protoc-gen-go v1.36.10
// protoc v5.27.1 // protoc v5.27.1
// source: grpc/binlog/v1/binarylog.proto // source: grpc/binlog/v1/binarylog.proto

View File

@@ -35,16 +35,19 @@ import (
"google.golang.org/grpc/balancer/pickfirst" "google.golang.org/grpc/balancer/pickfirst"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
expstats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/idle" "google.golang.org/grpc/internal/idle"
iresolver "google.golang.org/grpc/internal/resolver" iresolver "google.golang.org/grpc/internal/resolver"
"google.golang.org/grpc/internal/stats" istats "google.golang.org/grpc/internal/stats"
"google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/serviceconfig"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
_ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin. _ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin.
@@ -97,6 +100,41 @@ var (
errTransportCredentialsMissing = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportCredentials() to set)") errTransportCredentialsMissing = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportCredentials() to set)")
) )
var (
disconnectionsMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.subchannel.disconnections",
Description: "EXPERIMENTAL. Number of times the selected subchannel becomes disconnected.",
Unit: "{disconnection}",
Labels: []string{"grpc.target"},
OptionalLabels: []string{"grpc.lb.backend_service", "grpc.lb.locality", "grpc.disconnect_error"},
Default: false,
})
connectionAttemptsSucceededMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.subchannel.connection_attempts_succeeded",
Description: "EXPERIMENTAL. Number of successful connection attempts.",
Unit: "{attempt}",
Labels: []string{"grpc.target"},
OptionalLabels: []string{"grpc.lb.backend_service", "grpc.lb.locality"},
Default: false,
})
connectionAttemptsFailedMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.subchannel.connection_attempts_failed",
Description: "EXPERIMENTAL. Number of failed connection attempts.",
Unit: "{attempt}",
Labels: []string{"grpc.target"},
OptionalLabels: []string{"grpc.lb.backend_service", "grpc.lb.locality"},
Default: false,
})
openConnectionsMetric = expstats.RegisterInt64UpDownCount(expstats.MetricDescriptor{
Name: "grpc.subchannel.open_connections",
Description: "EXPERIMENTAL. Number of open connections.",
Unit: "{attempt}",
Labels: []string{"grpc.target"},
OptionalLabels: []string{"grpc.lb.backend_service", "grpc.security_level", "grpc.lb.locality"},
Default: false,
})
)
const ( const (
defaultClientMaxReceiveMessageSize = 1024 * 1024 * 4 defaultClientMaxReceiveMessageSize = 1024 * 1024 * 4
defaultClientMaxSendMessageSize = math.MaxInt32 defaultClientMaxSendMessageSize = math.MaxInt32
@@ -208,9 +246,10 @@ func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error)
channelz.Infof(logger, cc.channelz, "Channel authority set to %q", cc.authority) channelz.Infof(logger, cc.channelz, "Channel authority set to %q", cc.authority)
cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelz) cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelz)
cc.pickerWrapper = newPickerWrapper(cc.dopts.copts.StatsHandlers) cc.pickerWrapper = newPickerWrapper()
cc.metricsRecorderList = stats.NewMetricsRecorderList(cc.dopts.copts.StatsHandlers) cc.metricsRecorderList = istats.NewMetricsRecorderList(cc.dopts.copts.StatsHandlers)
cc.statsHandler = istats.NewCombinedHandler(cc.dopts.copts.StatsHandlers...)
cc.initIdleStateLocked() // Safe to call without the lock, since nothing else has a reference to cc. cc.initIdleStateLocked() // Safe to call without the lock, since nothing else has a reference to cc.
cc.idlenessMgr = idle.NewManager((*idler)(cc), cc.dopts.idleTimeout) cc.idlenessMgr = idle.NewManager((*idler)(cc), cc.dopts.idleTimeout)
@@ -260,9 +299,10 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
}() }()
// This creates the name resolver, load balancer, etc. // This creates the name resolver, load balancer, etc.
if err := cc.idlenessMgr.ExitIdleMode(); err != nil { if err := cc.exitIdleMode(); err != nil {
return nil, err return nil, fmt.Errorf("failed to exit idle mode: %w", err)
} }
cc.idlenessMgr.UnsafeSetNotIdle()
// Return now for non-blocking dials. // Return now for non-blocking dials.
if !cc.dopts.block { if !cc.dopts.block {
@@ -330,7 +370,7 @@ func (cc *ClientConn) addTraceEvent(msg string) {
Severity: channelz.CtInfo, Severity: channelz.CtInfo,
} }
} }
channelz.AddTraceEvent(logger, cc.channelz, 0, ted) channelz.AddTraceEvent(logger, cc.channelz, 1, ted)
} }
type idler ClientConn type idler ClientConn
@@ -339,14 +379,17 @@ func (i *idler) EnterIdleMode() {
(*ClientConn)(i).enterIdleMode() (*ClientConn)(i).enterIdleMode()
} }
func (i *idler) ExitIdleMode() error { func (i *idler) ExitIdleMode() {
return (*ClientConn)(i).exitIdleMode() // Ignore the error returned from this method, because from the perspective
// of the caller (idleness manager), the channel would have always moved out
// of IDLE by the time this method returns.
(*ClientConn)(i).exitIdleMode()
} }
// exitIdleMode moves the channel out of idle mode by recreating the name // exitIdleMode moves the channel out of idle mode by recreating the name
// resolver and load balancer. This should never be called directly; use // resolver and load balancer. This should never be called directly; use
// cc.idlenessMgr.ExitIdleMode instead. // cc.idlenessMgr.ExitIdleMode instead.
func (cc *ClientConn) exitIdleMode() (err error) { func (cc *ClientConn) exitIdleMode() error {
cc.mu.Lock() cc.mu.Lock()
if cc.conns == nil { if cc.conns == nil {
cc.mu.Unlock() cc.mu.Unlock()
@@ -354,11 +397,23 @@ func (cc *ClientConn) exitIdleMode() (err error) {
} }
cc.mu.Unlock() cc.mu.Unlock()
// Set state to CONNECTING before building the name resolver
// so the channel does not remain in IDLE.
cc.csMgr.updateState(connectivity.Connecting)
// This needs to be called without cc.mu because this builds a new resolver // This needs to be called without cc.mu because this builds a new resolver
// which might update state or report error inline, which would then need to // which might update state or report error inline, which would then need to
// acquire cc.mu. // acquire cc.mu.
if err := cc.resolverWrapper.start(); err != nil { if err := cc.resolverWrapper.start(); err != nil {
return err // If resolver creation fails, treat it like an error reported by the
// resolver before any valid updates. Set channel's state to
// TransientFailure, and set an erroring picker with the resolver build
// error, which will returned as part of any subsequent RPCs.
logger.Warningf("Failed to start resolver: %v", err)
cc.csMgr.updateState(connectivity.TransientFailure)
cc.mu.Lock()
cc.updateResolverStateAndUnlock(resolver.State{}, err)
return fmt.Errorf("failed to start resolver: %w", err)
} }
cc.addTraceEvent("exiting idle mode") cc.addTraceEvent("exiting idle mode")
@@ -456,7 +511,7 @@ func (cc *ClientConn) validateTransportCredentials() error {
func (cc *ClientConn) channelzRegistration(target string) { func (cc *ClientConn) channelzRegistration(target string) {
parentChannel, _ := cc.dopts.channelzParent.(*channelz.Channel) parentChannel, _ := cc.dopts.channelzParent.(*channelz.Channel)
cc.channelz = channelz.RegisterChannel(parentChannel, target) cc.channelz = channelz.RegisterChannel(parentChannel, target)
cc.addTraceEvent("created") cc.addTraceEvent(fmt.Sprintf("created for target %q", target))
} }
// chainUnaryClientInterceptors chains all unary client interceptors into one. // chainUnaryClientInterceptors chains all unary client interceptors into one.
@@ -621,7 +676,8 @@ type ClientConn struct {
channelz *channelz.Channel // Channelz object. channelz *channelz.Channel // Channelz object.
resolverBuilder resolver.Builder // See initParsedTargetAndResolverBuilder(). resolverBuilder resolver.Builder // See initParsedTargetAndResolverBuilder().
idlenessMgr *idle.Manager idlenessMgr *idle.Manager
metricsRecorderList *stats.MetricsRecorderList metricsRecorderList *istats.MetricsRecorderList
statsHandler stats.Handler
// The following provide their own synchronization, and therefore don't // The following provide their own synchronization, and therefore don't
// require cc.mu to be held to access them. // require cc.mu to be held to access them.
@@ -678,10 +734,8 @@ func (cc *ClientConn) GetState() connectivity.State {
// Notice: This API is EXPERIMENTAL and may be changed or removed in a later // Notice: This API is EXPERIMENTAL and may be changed or removed in a later
// release. // release.
func (cc *ClientConn) Connect() { func (cc *ClientConn) Connect() {
if err := cc.idlenessMgr.ExitIdleMode(); err != nil { cc.idlenessMgr.ExitIdleMode()
cc.addTraceEvent(err.Error())
return
}
// If the ClientConn was not in idle mode, we need to call ExitIdle on the // If the ClientConn was not in idle mode, we need to call ExitIdle on the
// LB policy so that connections can be created. // LB policy so that connections can be created.
cc.mu.Lock() cc.mu.Lock()
@@ -732,8 +786,8 @@ func init() {
internal.EnterIdleModeForTesting = func(cc *ClientConn) { internal.EnterIdleModeForTesting = func(cc *ClientConn) {
cc.idlenessMgr.EnterIdleModeForTesting() cc.idlenessMgr.EnterIdleModeForTesting()
} }
internal.ExitIdleModeForTesting = func(cc *ClientConn) error { internal.ExitIdleModeForTesting = func(cc *ClientConn) {
return cc.idlenessMgr.ExitIdleMode() cc.idlenessMgr.ExitIdleMode()
} }
} }
@@ -858,6 +912,7 @@ func (cc *ClientConn) newAddrConnLocked(addrs []resolver.Address, opts balancer.
channelz: channelz.RegisterSubChannel(cc.channelz, ""), channelz: channelz.RegisterSubChannel(cc.channelz, ""),
resetBackoff: make(chan struct{}), resetBackoff: make(chan struct{}),
} }
ac.updateTelemetryLabelsLocked()
ac.ctx, ac.cancel = context.WithCancel(cc.ctx) ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
// Start with our address set to the first address; this may be updated if // Start with our address set to the first address; this may be updated if
// we connect to different addresses. // we connect to different addresses.
@@ -922,25 +977,24 @@ func (cc *ClientConn) incrCallsFailed() {
// connect starts creating a transport. // connect starts creating a transport.
// It does nothing if the ac is not IDLE. // It does nothing if the ac is not IDLE.
// TODO(bar) Move this to the addrConn section. // TODO(bar) Move this to the addrConn section.
func (ac *addrConn) connect() error { func (ac *addrConn) connect() {
ac.mu.Lock() ac.mu.Lock()
if ac.state == connectivity.Shutdown { if ac.state == connectivity.Shutdown {
if logger.V(2) { if logger.V(2) {
logger.Infof("connect called on shutdown addrConn; ignoring.") logger.Infof("connect called on shutdown addrConn; ignoring.")
} }
ac.mu.Unlock() ac.mu.Unlock()
return errConnClosing return
} }
if ac.state != connectivity.Idle { if ac.state != connectivity.Idle {
if logger.V(2) { if logger.V(2) {
logger.Infof("connect called on addrConn in non-idle state (%v); ignoring.", ac.state) logger.Infof("connect called on addrConn in non-idle state (%v); ignoring.", ac.state)
} }
ac.mu.Unlock() ac.mu.Unlock()
return nil return
} }
ac.resetTransportAndUnlock() ac.resetTransportAndUnlock()
return nil
} }
// equalAddressIgnoringBalAttributes returns true is a and b are considered equal. // equalAddressIgnoringBalAttributes returns true is a and b are considered equal.
@@ -974,7 +1028,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
} }
ac.addrs = addrs ac.addrs = addrs
ac.updateTelemetryLabelsLocked()
if ac.state == connectivity.Shutdown || if ac.state == connectivity.Shutdown ||
ac.state == connectivity.TransientFailure || ac.state == connectivity.TransientFailure ||
ac.state == connectivity.Idle { ac.state == connectivity.Idle {
@@ -1076,13 +1130,6 @@ func (cc *ClientConn) healthCheckConfig() *healthCheckConfig {
return cc.sc.healthCheckConfig return cc.sc.healthCheckConfig
} }
func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method string) (transport.ClientTransport, balancer.PickResult, error) {
return cc.pickerWrapper.pick(ctx, failfast, balancer.PickInfo{
Ctx: ctx,
FullMethodName: method,
})
}
func (cc *ClientConn) applyServiceConfigAndBalancer(sc *ServiceConfig, configSelector iresolver.ConfigSelector) { func (cc *ClientConn) applyServiceConfigAndBalancer(sc *ServiceConfig, configSelector iresolver.ConfigSelector) {
if sc == nil { if sc == nil {
// should never reach here. // should never reach here.
@@ -1220,6 +1267,9 @@ type addrConn struct {
resetBackoff chan struct{} resetBackoff chan struct{}
channelz *channelz.SubChannel channelz *channelz.SubChannel
localityLabel string
backendServiceLabel string
} }
// Note: this requires a lock on ac.mu. // Note: this requires a lock on ac.mu.
@@ -1227,6 +1277,18 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
if ac.state == s { if ac.state == s {
return return
} }
// If we are transitioning out of Ready, it means there is a disconnection.
// A SubConn can also transition from CONNECTING directly to IDLE when
// a transport is successfully created, but the connection fails
// before the SubConn can send the notification for READY. We treat
// this as a successful connection and transition to IDLE.
// TODO: https://github.com/grpc/grpc-go/issues/7862 - Remove the second
// part of the if condition below once the issue is fixed.
if ac.state == connectivity.Ready || (ac.state == connectivity.Connecting && s == connectivity.Idle) {
disconnectionsMetric.Record(ac.cc.metricsRecorderList, 1, ac.cc.target, ac.backendServiceLabel, ac.localityLabel, "unknown")
openConnectionsMetric.Record(ac.cc.metricsRecorderList, -1, ac.cc.target, ac.backendServiceLabel, ac.securityLevelLocked(), ac.localityLabel)
}
ac.state = s ac.state = s
ac.channelz.ChannelMetrics.State.Store(&s) ac.channelz.ChannelMetrics.State.Store(&s)
if lastErr == nil { if lastErr == nil {
@@ -1234,7 +1296,7 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
} else { } else {
channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr) channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr)
} }
ac.acbw.updateState(s, ac.curAddr, lastErr) ac.acbw.updateState(s, lastErr)
} }
// adjustParams updates parameters used to create transports upon // adjustParams updates parameters used to create transports upon
@@ -1284,6 +1346,15 @@ func (ac *addrConn) resetTransportAndUnlock() {
ac.mu.Unlock() ac.mu.Unlock()
if err := ac.tryAllAddrs(acCtx, addrs, connectDeadline); err != nil { if err := ac.tryAllAddrs(acCtx, addrs, connectDeadline); err != nil {
if !errors.Is(err, context.Canceled) {
connectionAttemptsFailedMetric.Record(ac.cc.metricsRecorderList, 1, ac.cc.target, ac.backendServiceLabel, ac.localityLabel)
} else {
if logger.V(2) {
// This records cancelled connection attempts which can be later
// replaced by a metric.
logger.Infof("Context cancellation detected; not recording this as a failed connection attempt.")
}
}
// TODO: #7534 - Move re-resolution requests into the pick_first LB policy // TODO: #7534 - Move re-resolution requests into the pick_first LB policy
// to ensure one resolution request per pass instead of per subconn failure. // to ensure one resolution request per pass instead of per subconn failure.
ac.cc.resolveNow(resolver.ResolveNowOptions{}) ac.cc.resolveNow(resolver.ResolveNowOptions{})
@@ -1323,10 +1394,50 @@ func (ac *addrConn) resetTransportAndUnlock() {
} }
// Success; reset backoff. // Success; reset backoff.
ac.mu.Lock() ac.mu.Lock()
connectionAttemptsSucceededMetric.Record(ac.cc.metricsRecorderList, 1, ac.cc.target, ac.backendServiceLabel, ac.localityLabel)
openConnectionsMetric.Record(ac.cc.metricsRecorderList, 1, ac.cc.target, ac.backendServiceLabel, ac.securityLevelLocked(), ac.localityLabel)
ac.backoffIdx = 0 ac.backoffIdx = 0
ac.mu.Unlock() ac.mu.Unlock()
} }
// updateTelemetryLabelsLocked calculates and caches the telemetry labels based on the
// first address in addrConn.
func (ac *addrConn) updateTelemetryLabelsLocked() {
labelsFunc, ok := internal.AddressToTelemetryLabels.(func(resolver.Address) map[string]string)
if !ok || len(ac.addrs) == 0 {
// Reset defaults
ac.localityLabel = ""
ac.backendServiceLabel = ""
return
}
labels := labelsFunc(ac.addrs[0])
ac.localityLabel = labels["grpc.lb.locality"]
ac.backendServiceLabel = labels["grpc.lb.backend_service"]
}
type securityLevelKey struct{}
func (ac *addrConn) securityLevelLocked() string {
var secLevel string
// During disconnection, ac.transport is nil. Fall back to the security level
// stored in the current address during connection.
if ac.transport == nil {
secLevel, _ = ac.curAddr.Attributes.Value(securityLevelKey{}).(string)
return secLevel
}
authInfo := ac.transport.Peer().AuthInfo
if ci, ok := authInfo.(interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}); ok {
secLevel = ci.GetCommonAuthInfo().SecurityLevel.String()
// Store the security level in the current address' attributes so
// that it remains available for disconnection metrics after the
// transport is closed.
ac.curAddr.Attributes = ac.curAddr.Attributes.WithValue(securityLevelKey{}, secLevel)
}
return secLevel
}
// tryAllAddrs tries to create a connection to the addresses, and stop when at // tryAllAddrs tries to create a connection to the addresses, and stop when at
// the first successful one. It returns an error if no address was successfully // the first successful one. It returns an error if no address was successfully
// connected, or updates ac appropriately with the new transport. // connected, or updates ac appropriately with the new transport.
@@ -1416,25 +1527,26 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
} }
ac.mu.Lock() ac.mu.Lock()
defer ac.mu.Unlock()
if ctx.Err() != nil { if ctx.Err() != nil {
// This can happen if the subConn was removed while in `Connecting` // This can happen if the subConn was removed while in `Connecting`
// state. tearDown() would have set the state to `Shutdown`, but // state. tearDown() would have set the state to `Shutdown`, but
// would not have closed the transport since ac.transport would not // would not have closed the transport since ac.transport would not
// have been set at that point. // have been set at that point.
//
// We run this in a goroutine because newTr.Close() calls onClose() // We unlock ac.mu because newTr.Close() calls onClose()
// inline, which requires locking ac.mu. // inline, which requires locking ac.mu.
// ac.mu.Unlock()
// The error we pass to Close() is immaterial since there are no open // The error we pass to Close() is immaterial since there are no open
// streams at this point, so no trailers with error details will be sent // streams at this point, so no trailers with error details will be sent
// out. We just need to pass a non-nil error. // out. We just need to pass a non-nil error.
// //
// This can also happen when updateAddrs is called during a connection // This can also happen when updateAddrs is called during a connection
// attempt. // attempt.
go newTr.Close(transport.ErrConnClosing) newTr.Close(transport.ErrConnClosing)
return nil return nil
} }
defer ac.mu.Unlock()
if hctx.Err() != nil { if hctx.Err() != nil {
// onClose was already called for this connection, but the connection // onClose was already called for this connection, but the connection
// was successfully established first. Consider it a success and set // was successfully established first. Consider it a success and set
@@ -1831,7 +1943,7 @@ func (cc *ClientConn) initAuthority() error {
} else if auth, ok := cc.resolverBuilder.(resolver.AuthorityOverrider); ok { } else if auth, ok := cc.resolverBuilder.(resolver.AuthorityOverrider); ok {
cc.authority = auth.OverrideAuthority(cc.parsedTarget) cc.authority = auth.OverrideAuthority(cc.parsedTarget)
} else if strings.HasPrefix(endpoint, ":") { } else if strings.HasPrefix(endpoint, ":") {
cc.authority = "localhost" + endpoint cc.authority = "localhost" + encodeAuthority(endpoint)
} else { } else {
cc.authority = encodeAuthority(endpoint) cc.authority = encodeAuthority(endpoint)
} }

View File

@@ -44,8 +44,7 @@ type PerRPCCredentials interface {
// A54). uri is the URI of the entry point for the request. When supported // A54). uri is the URI of the entry point for the request. When supported
// by the underlying implementation, ctx can be used for timeout and // by the underlying implementation, ctx can be used for timeout and
// cancellation. Additionally, RequestInfo data will be available via ctx // cancellation. Additionally, RequestInfo data will be available via ctx
// to this call. TODO(zhaoq): Define the set of the qualified keys instead // to this call.
// of leaving it as an arbitrary string.
GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error)
// RequireTransportSecurity indicates whether the credentials requires // RequireTransportSecurity indicates whether the credentials requires
// transport security. // transport security.
@@ -96,10 +95,11 @@ func (c CommonAuthInfo) GetCommonAuthInfo() CommonAuthInfo {
return c return c
} }
// ProtocolInfo provides information regarding the gRPC wire protocol version, // ProtocolInfo provides static information regarding transport credentials.
// security protocol, security protocol version in use, server name, etc.
type ProtocolInfo struct { type ProtocolInfo struct {
// ProtocolVersion is the gRPC wire protocol version. // ProtocolVersion is the gRPC wire protocol version.
//
// Deprecated: this is unused by gRPC.
ProtocolVersion string ProtocolVersion string
// SecurityProtocol is the security protocol in use. // SecurityProtocol is the security protocol in use.
SecurityProtocol string SecurityProtocol string
@@ -109,7 +109,16 @@ type ProtocolInfo struct {
// //
// Deprecated: please use Peer.AuthInfo. // Deprecated: please use Peer.AuthInfo.
SecurityVersion string SecurityVersion string
// ServerName is the user-configured server name. // ServerName is the user-configured server name. If set, this overrides
// the default :authority header used for all RPCs on the channel using the
// containing credentials, unless grpc.WithAuthority is set on the channel,
// in which case that setting will take precedence.
//
// This must be a valid `:authority` header according to
// [RFC3986](https://datatracker.ietf.org/doc/html/rfc3986#section-3.2).
//
// Deprecated: Users should use grpc.WithAuthority to override the authority
// on a channel instead of configuring the credentials.
ServerName string ServerName string
} }
@@ -173,12 +182,17 @@ type TransportCredentials interface {
// Clone makes a copy of this TransportCredentials. // Clone makes a copy of this TransportCredentials.
Clone() TransportCredentials Clone() TransportCredentials
// OverrideServerName specifies the value used for the following: // OverrideServerName specifies the value used for the following:
//
// - verifying the hostname on the returned certificates // - verifying the hostname on the returned certificates
// - as SNI in the client's handshake to support virtual hosting // - as SNI in the client's handshake to support virtual hosting
// - as the value for `:authority` header at stream creation time // - as the value for `:authority` header at stream creation time
// //
// Deprecated: use grpc.WithAuthority instead. Will be supported // The provided string should be a valid `:authority` header according to
// throughout 1.x. // [RFC3986](https://datatracker.ietf.org/doc/html/rfc3986#section-3.2).
//
// Deprecated: this method is unused by gRPC. Users should use
// grpc.WithAuthority to override the authority on a channel instead of
// configuring the credentials.
OverrideServerName(string) error OverrideServerName(string) error
} }

View File

@@ -56,9 +56,13 @@ func (t TLSInfo) AuthType() string {
// non-nil error if the validation fails. // non-nil error if the validation fails.
func (t TLSInfo) ValidateAuthority(authority string) error { func (t TLSInfo) ValidateAuthority(authority string) error {
var errs []error var errs []error
host, _, err := net.SplitHostPort(authority)
if err != nil {
host = authority
}
for _, cert := range t.State.PeerCertificates { for _, cert := range t.State.PeerCertificates {
var err error var err error
if err = cert.VerifyHostname(authority); err == nil { if err = cert.VerifyHostname(host); err == nil {
return nil return nil
} }
errs = append(errs, err) errs = append(errs, err)
@@ -110,14 +114,14 @@ func (c tlsCreds) Info() ProtocolInfo {
func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) { func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
// use local cfg to avoid clobbering ServerName if using multiple endpoints // use local cfg to avoid clobbering ServerName if using multiple endpoints
cfg := credinternal.CloneTLSConfig(c.config) cfg := credinternal.CloneTLSConfig(c.config)
if cfg.ServerName == "" {
serverName, _, err := net.SplitHostPort(authority) serverName, _, err := net.SplitHostPort(authority)
if err != nil { if err != nil {
// If the authority had no host port or if the authority cannot be parsed, use it as-is. // If the authority had no host port or if the authority cannot be parsed, use it as-is.
serverName = authority serverName = authority
} }
cfg.ServerName = serverName cfg.ServerName = serverName
}
conn := tls.Client(rawConn, cfg) conn := tls.Client(rawConn, cfg)
errChannel := make(chan error, 1) errChannel := make(chan error, 1)
go func() { go func() {
@@ -259,9 +263,11 @@ func applyDefaults(c *tls.Config) *tls.Config {
// certificates to establish the identity of the client need to be included in // certificates to establish the identity of the client need to be included in
// the credentials (eg: for mTLS), use NewTLS instead, where a complete // the credentials (eg: for mTLS), use NewTLS instead, where a complete
// tls.Config can be specified. // tls.Config can be specified.
// serverNameOverride is for testing only. If set to a non empty string, //
// it will override the virtual host name of authority (e.g. :authority header // serverNameOverride is for testing only. If set to a non empty string, it will
// field) in requests. // override the virtual host name of authority (e.g. :authority header field) in
// requests. Users should use grpc.WithAuthority passed to grpc.NewClient to
// override the authority of the client instead.
func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) TransportCredentials { func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) TransportCredentials {
return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp}) return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp})
} }
@@ -271,9 +277,11 @@ func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) Transpor
// certificates to establish the identity of the client need to be included in // certificates to establish the identity of the client need to be included in
// the credentials (eg: for mTLS), use NewTLS instead, where a complete // the credentials (eg: for mTLS), use NewTLS instead, where a complete
// tls.Config can be specified. // tls.Config can be specified.
// serverNameOverride is for testing only. If set to a non empty string, //
// it will override the virtual host name of authority (e.g. :authority header // serverNameOverride is for testing only. If set to a non empty string, it will
// field) in requests. // override the virtual host name of authority (e.g. :authority header field) in
// requests. Users should use grpc.WithAuthority passed to grpc.NewClient to
// override the authority of the client instead.
func NewClientTLSFromFile(certFile, serverNameOverride string) (TransportCredentials, error) { func NewClientTLSFromFile(certFile, serverNameOverride string) (TransportCredentials, error) {
b, err := os.ReadFile(certFile) b, err := os.ReadFile(certFile)
if err != nil { if err != nil {

View File

@@ -608,6 +608,8 @@ func WithChainStreamInterceptor(interceptors ...StreamClientInterceptor) DialOpt
// WithAuthority returns a DialOption that specifies the value to be used as the // WithAuthority returns a DialOption that specifies the value to be used as the
// :authority pseudo-header and as the server name in authentication handshake. // :authority pseudo-header and as the server name in authentication handshake.
// This overrides all other ways of setting authority on the channel, but can be
// overridden per-call by using grpc.CallAuthority.
func WithAuthority(a string) DialOption { func WithAuthority(a string) DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.authority = a o.authority = a

View File

@@ -27,8 +27,10 @@ package encoding
import ( import (
"io" "io"
"slices"
"strings" "strings"
"google.golang.org/grpc/encoding/internal"
"google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/grpcutil"
) )
@@ -36,12 +38,26 @@ import (
// It is intended for grpc internal use only. // It is intended for grpc internal use only.
const Identity = "identity" const Identity = "identity"
func init() {
internal.RegisterCompressorForTesting = func(c Compressor) func() {
name := c.Name()
curCompressor, found := registeredCompressor[name]
RegisterCompressor(c)
return func() {
if found {
registeredCompressor[name] = curCompressor
return
}
delete(registeredCompressor, name)
grpcutil.RegisteredCompressorNames = slices.DeleteFunc(grpcutil.RegisteredCompressorNames, func(s string) bool {
return s == name
})
}
}
}
// Compressor is used for compressing and decompressing when sending or // Compressor is used for compressing and decompressing when sending or
// receiving messages. // receiving messages.
//
// If a Compressor implements `DecompressedSize(compressedBytes []byte) int`,
// gRPC will invoke it to determine the size of the buffer allocated for the
// result of decompression. A return value of -1 indicates unknown size.
type Compressor interface { type Compressor interface {
// Compress writes the data written to wc to w after compressing it. If an // Compress writes the data written to wc to w after compressing it. If an
// error occurs while initializing the compressor, that error is returned // error occurs while initializing the compressor, that error is returned

View File

@@ -0,0 +1,28 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package internal contains code internal to the encoding package.
package internal
// RegisterCompressorForTesting registers a compressor in the global compressor
// registry. It returns a cleanup function that should be called at the end
// of the test to unregister the compressor.
//
// This prevents compressors registered in one test from appearing in the
// encoding headers of subsequent tests.
var RegisterCompressorForTesting any // func RegisterCompressor(c Compressor) func()

View File

@@ -46,9 +46,25 @@ func (c *codecV2) Marshal(v any) (data mem.BufferSlice, err error) {
return nil, fmt.Errorf("proto: failed to marshal, message is %T, want proto.Message", v) return nil, fmt.Errorf("proto: failed to marshal, message is %T, want proto.Message", v)
} }
// Important: if we remove this Size call then we cannot use
// UseCachedSize in MarshalOptions below.
size := proto.Size(vv) size := proto.Size(vv)
// MarshalOptions with UseCachedSize allows reusing the result from the
// previous Size call. This is safe here because:
//
// 1. We just computed the size.
// 2. We assume the message is not being mutated concurrently.
//
// Important: If the proto.Size call above is removed, using UseCachedSize
// becomes unsafe and may lead to incorrect marshaling.
//
// For more details, see the doc of UseCachedSize:
// https://pkg.go.dev/google.golang.org/protobuf/proto#MarshalOptions
marshalOptions := proto.MarshalOptions{UseCachedSize: true}
if mem.IsBelowBufferPoolingThreshold(size) { if mem.IsBelowBufferPoolingThreshold(size) {
buf, err := proto.Marshal(vv) buf, err := marshalOptions.Marshal(vv)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -56,7 +72,7 @@ func (c *codecV2) Marshal(v any) (data mem.BufferSlice, err error) {
} else { } else {
pool := mem.DefaultBufferPool() pool := mem.DefaultBufferPool()
buf := pool.Get(size) buf := pool.Get(size)
if _, err := (proto.MarshalOptions{}).MarshalAppend((*buf)[:0], vv); err != nil { if _, err := marshalOptions.MarshalAppend((*buf)[:0], vv); err != nil {
pool.Put(buf) pool.Put(buf)
return nil, err return nil, err
} }

View File

@@ -75,6 +75,8 @@ const (
MetricTypeIntHisto MetricTypeIntHisto
MetricTypeFloatHisto MetricTypeFloatHisto
MetricTypeIntGauge MetricTypeIntGauge
MetricTypeIntUpDownCount
MetricTypeIntAsyncGauge
) )
// Int64CountHandle is a typed handle for a int count metric. This handle // Int64CountHandle is a typed handle for a int count metric. This handle
@@ -93,6 +95,23 @@ func (h *Int64CountHandle) Record(recorder MetricsRecorder, incr int64, labels .
recorder.RecordInt64Count(h, incr, labels...) recorder.RecordInt64Count(h, incr, labels...)
} }
// Int64UpDownCountHandle is a typed handle for an int up-down counter metric.
// This handle is passed at the recording point in order to know which metric
// to record on.
type Int64UpDownCountHandle MetricDescriptor
// Descriptor returns the int64 up-down counter handle typecast to a pointer to a
// MetricDescriptor.
func (h *Int64UpDownCountHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the int64 up-down counter value on the metrics recorder provided.
// The value 'v' can be positive to increment or negative to decrement.
func (h *Int64UpDownCountHandle) Record(recorder MetricsRecorder, v int64, labels ...string) {
recorder.RecordInt64UpDownCount(h, v, labels...)
}
// Float64CountHandle is a typed handle for a float count metric. This handle is // Float64CountHandle is a typed handle for a float count metric. This handle is
// passed at the recording point in order to know which metric to record on. // passed at the recording point in order to know which metric to record on.
type Float64CountHandle MetricDescriptor type Float64CountHandle MetricDescriptor
@@ -154,6 +173,30 @@ func (h *Int64GaugeHandle) Record(recorder MetricsRecorder, incr int64, labels .
recorder.RecordInt64Gauge(h, incr, labels...) recorder.RecordInt64Gauge(h, incr, labels...)
} }
// AsyncMetric is a marker interface for asynchronous metric types.
type AsyncMetric interface {
isAsync()
Descriptor() *MetricDescriptor
}
// Int64AsyncGaugeHandle is a typed handle for an int gauge metric. This handle is
// passed at the recording point in order to know which metric to record on.
type Int64AsyncGaugeHandle MetricDescriptor
// isAsync implements the AsyncMetric interface.
func (h *Int64AsyncGaugeHandle) isAsync() {}
// Descriptor returns the int64 gauge handle typecast to a pointer to a
// MetricDescriptor.
func (h *Int64AsyncGaugeHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the int64 gauge value on the metrics recorder provided.
func (h *Int64AsyncGaugeHandle) Record(recorder AsyncMetricsRecorder, value int64, labels ...string) {
recorder.RecordInt64AsyncGauge(h, value, labels...)
}
// registeredMetrics are the registered metric descriptor names. // registeredMetrics are the registered metric descriptor names.
var registeredMetrics = make(map[string]bool) var registeredMetrics = make(map[string]bool)
@@ -249,6 +292,35 @@ func RegisterInt64Gauge(descriptor MetricDescriptor) *Int64GaugeHandle {
return (*Int64GaugeHandle)(descPtr) return (*Int64GaugeHandle)(descPtr)
} }
// RegisterInt64UpDownCount registers the metric description onto the global registry.
// It returns a typed handle to use for recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterInt64UpDownCount(descriptor MetricDescriptor) *Int64UpDownCountHandle {
registerMetric(descriptor.Name, descriptor.Default)
// Set the specific metric type for the up-down counter
descriptor.Type = MetricTypeIntUpDownCount
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Int64UpDownCountHandle)(descPtr)
}
// RegisterInt64AsyncGauge registers the metric description onto the global registry.
// It returns a typed handle to use for recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterInt64AsyncGauge(descriptor MetricDescriptor) *Int64AsyncGaugeHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeIntAsyncGauge
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Int64AsyncGaugeHandle)(descPtr)
}
// snapshotMetricsRegistryForTesting snapshots the global data of the metrics // snapshotMetricsRegistryForTesting snapshots the global data of the metrics
// registry. Returns a cleanup function that sets the metrics registry to its // registry. Returns a cleanup function that sets the metrics registry to its
// original state. // original state.

View File

@@ -19,9 +19,13 @@
// Package stats contains experimental metrics/stats API's. // Package stats contains experimental metrics/stats API's.
package stats package stats
import "google.golang.org/grpc/stats" import (
"google.golang.org/grpc/internal"
"google.golang.org/grpc/stats"
)
// MetricsRecorder records on metrics derived from metric registry. // MetricsRecorder records on metrics derived from metric registry.
// Implementors must embed UnimplementedMetricsRecorder.
type MetricsRecorder interface { type MetricsRecorder interface {
// RecordInt64Count records the measurement alongside labels on the int // RecordInt64Count records the measurement alongside labels on the int
// count associated with the provided handle. // count associated with the provided handle.
@@ -38,6 +42,49 @@ type MetricsRecorder interface {
// RecordInt64Gauge records the measurement alongside labels on the int // RecordInt64Gauge records the measurement alongside labels on the int
// gauge associated with the provided handle. // gauge associated with the provided handle.
RecordInt64Gauge(handle *Int64GaugeHandle, incr int64, labels ...string) RecordInt64Gauge(handle *Int64GaugeHandle, incr int64, labels ...string)
// RecordInt64UpDownCounter records the measurement alongside labels on the int
// count associated with the provided handle.
RecordInt64UpDownCount(handle *Int64UpDownCountHandle, incr int64, labels ...string)
// RegisterAsyncReporter registers a reporter to produce metric values for
// only the listed descriptors. The returned function must be called when
// the metrics are no longer needed, which will remove the reporter. The
// returned method needs to be idempotent and concurrent safe.
RegisterAsyncReporter(reporter AsyncMetricReporter, descriptors ...AsyncMetric) func()
// EnforceMetricsRecorderEmbedding is included to force implementers to embed
// another implementation of this interface, allowing gRPC to add methods
// without breaking users.
internal.EnforceMetricsRecorderEmbedding
}
// AsyncMetricReporter is an interface for types that record metrics asynchronously
// for the set of descriptors they are registered with. The AsyncMetricsRecorder
// parameter is used to record values for these metrics.
//
// Implementations must make unique recordings across all registered
// AsyncMetricReporters. Meaning, they should not report values for a metric with
// the same attributes as another AsyncMetricReporter will report.
//
// Implementations must be concurrent-safe.
type AsyncMetricReporter interface {
// Report records metric values using the provided recorder.
Report(AsyncMetricsRecorder) error
}
// AsyncMetricReporterFunc is an adapter to allow the use of ordinary functions as
// AsyncMetricReporters.
type AsyncMetricReporterFunc func(AsyncMetricsRecorder) error
// Report calls f(r).
func (f AsyncMetricReporterFunc) Report(r AsyncMetricsRecorder) error {
return f(r)
}
// AsyncMetricsRecorder records on asynchronous metrics derived from metric registry.
type AsyncMetricsRecorder interface {
// RecordInt64AsyncGauge records the measurement alongside labels on the int
// count associated with the provided handle asynchronously
RecordInt64AsyncGauge(handle *Int64AsyncGaugeHandle, incr int64, labels ...string)
} }
// Metrics is an experimental legacy alias of the now-stable stats.MetricSet. // Metrics is an experimental legacy alias of the now-stable stats.MetricSet.
@@ -52,3 +99,33 @@ type Metric = string
func NewMetrics(metrics ...Metric) *Metrics { func NewMetrics(metrics ...Metric) *Metrics {
return stats.NewMetricSet(metrics...) return stats.NewMetricSet(metrics...)
} }
// UnimplementedMetricsRecorder must be embedded to have forward compatible implementations.
type UnimplementedMetricsRecorder struct {
internal.EnforceMetricsRecorderEmbedding
}
// RecordInt64Count provides a no-op implementation.
func (UnimplementedMetricsRecorder) RecordInt64Count(*Int64CountHandle, int64, ...string) {}
// RecordFloat64Count provides a no-op implementation.
func (UnimplementedMetricsRecorder) RecordFloat64Count(*Float64CountHandle, float64, ...string) {}
// RecordInt64Histo provides a no-op implementation.
func (UnimplementedMetricsRecorder) RecordInt64Histo(*Int64HistoHandle, int64, ...string) {}
// RecordFloat64Histo provides a no-op implementation.
func (UnimplementedMetricsRecorder) RecordFloat64Histo(*Float64HistoHandle, float64, ...string) {}
// RecordInt64Gauge provides a no-op implementation.
func (UnimplementedMetricsRecorder) RecordInt64Gauge(*Int64GaugeHandle, int64, ...string) {}
// RecordInt64UpDownCount provides a no-op implementation.
func (UnimplementedMetricsRecorder) RecordInt64UpDownCount(*Int64UpDownCountHandle, int64, ...string) {
}
// RegisterAsyncReporter provides a no-op implementation.
func (UnimplementedMetricsRecorder) RegisterAsyncReporter(AsyncMetricReporter, ...AsyncMetric) func() {
// No-op: Return an empty function to ensure caller doesn't panic on nil function call
return func() {}
}

View File

@@ -97,8 +97,12 @@ type StreamServerInfo struct {
IsServerStream bool IsServerStream bool
} }
// StreamServerInterceptor provides a hook to intercept the execution of a streaming RPC on the server. // StreamServerInterceptor provides a hook to intercept the execution of a
// info contains all the information of this RPC the interceptor can operate on. And handler is the // streaming RPC on the server.
// service method implementation. It is the responsibility of the interceptor to invoke handler to //
// complete the RPC. // srv is the service implementation on which the RPC was invoked, and needs to
// be passed to handler, and not used otherwise. ss is the server side of the
// stream. info contains all the information of this RPC the interceptor can
// operate on. And handler is the service method implementation. It is the
// responsibility of the interceptor to invoke handler to complete the RPC.
type StreamServerInterceptor func(srv any, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error type StreamServerInterceptor func(srv any, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error

View File

@@ -67,6 +67,10 @@ type Balancer struct {
// balancerCurrent before the UpdateSubConnState is called on the // balancerCurrent before the UpdateSubConnState is called on the
// balancerCurrent. // balancerCurrent.
currentMu sync.Mutex currentMu sync.Mutex
// activeGoroutines tracks all the goroutines that this balancer has started
// and that should be waited on when the balancer closes.
activeGoroutines sync.WaitGroup
} }
// swap swaps out the current lb with the pending lb and updates the ClientConn. // swap swaps out the current lb with the pending lb and updates the ClientConn.
@@ -76,7 +80,9 @@ func (gsb *Balancer) swap() {
cur := gsb.balancerCurrent cur := gsb.balancerCurrent
gsb.balancerCurrent = gsb.balancerPending gsb.balancerCurrent = gsb.balancerPending
gsb.balancerPending = nil gsb.balancerPending = nil
gsb.activeGoroutines.Add(1)
go func() { go func() {
defer gsb.activeGoroutines.Done()
gsb.currentMu.Lock() gsb.currentMu.Lock()
defer gsb.currentMu.Unlock() defer gsb.currentMu.Unlock()
cur.Close() cur.Close()
@@ -274,6 +280,7 @@ func (gsb *Balancer) Close() {
currentBalancerToClose.Close() currentBalancerToClose.Close()
pendingBalancerToClose.Close() pendingBalancerToClose.Close()
gsb.activeGoroutines.Wait()
} }
// balancerWrapper wraps a balancer.Balancer, and overrides some Balancer // balancerWrapper wraps a balancer.Balancer, and overrides some Balancer
@@ -324,7 +331,12 @@ func (bw *balancerWrapper) UpdateState(state balancer.State) {
defer bw.gsb.mu.Unlock() defer bw.gsb.mu.Unlock()
bw.lastState = state bw.lastState = state
// If Close() acquires the mutex before UpdateState(), the balancer
// will already have been removed from the current or pending state when
// reaching this point.
if !bw.gsb.balancerCurrentOrPending(bw) { if !bw.gsb.balancerCurrentOrPending(bw) {
// Returning here ensures that (*Balancer).swap() is not invoked after
// (*Balancer).Close() and therefore prevents "use after close".
return return
} }

View File

@@ -0,0 +1,66 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package weight contains utilities to manage endpoint weights. Weights are
// used by LB policies such as ringhash to distribute load across multiple
// endpoints.
package weight
import (
"fmt"
"google.golang.org/grpc/resolver"
)
// attributeKey is the type used as the key to store EndpointInfo in the
// Attributes field of resolver.Endpoint.
type attributeKey struct{}
// EndpointInfo will be stored in the Attributes field of Endpoints in order to
// use the ringhash balancer.
type EndpointInfo struct {
Weight uint32
}
// Equal allows the values to be compared by Attributes.Equal.
func (a EndpointInfo) Equal(o any) bool {
oa, ok := o.(EndpointInfo)
return ok && oa.Weight == a.Weight
}
// Set returns a copy of endpoint in which the Attributes field is updated with
// EndpointInfo.
func Set(endpoint resolver.Endpoint, epInfo EndpointInfo) resolver.Endpoint {
endpoint.Attributes = endpoint.Attributes.WithValue(attributeKey{}, epInfo)
return endpoint
}
// String returns a human-readable representation of EndpointInfo.
// This method is intended for logging, testing, and debugging purposes only.
// Do not rely on the output format, as it is not guaranteed to remain stable.
func (a EndpointInfo) String() string {
return fmt.Sprintf("Weight: %d", a.Weight)
}
// FromEndpoint returns the EndpointInfo stored in the Attributes field of an
// endpoint. It returns an empty EndpointInfo if attribute is not found.
func FromEndpoint(endpoint resolver.Endpoint) EndpointInfo {
v := endpoint.Attributes.Value(attributeKey{})
ei, _ := v.(EndpointInfo)
return ei
}

View File

@@ -83,6 +83,7 @@ func (b *Unbounded) Load() {
default: default:
} }
} else if b.closing && !b.closed { } else if b.closing && !b.closed {
b.closed = true
close(b.c) close(b.c)
} }
} }

View File

@@ -26,31 +26,31 @@ import (
) )
var ( var (
// TXTErrIgnore is set if TXT errors should be ignored ("GRPC_GO_IGNORE_TXT_ERRORS" is not "false"). // EnableTXTServiceConfig is set if the DNS resolver should perform TXT
// lookups for service config ("GRPC_ENABLE_TXT_SERVICE_CONFIG" is not
// "false").
EnableTXTServiceConfig = boolFromEnv("GRPC_ENABLE_TXT_SERVICE_CONFIG", true)
// TXTErrIgnore is set if TXT errors should be ignored
// ("GRPC_GO_IGNORE_TXT_ERRORS" is not "false").
TXTErrIgnore = boolFromEnv("GRPC_GO_IGNORE_TXT_ERRORS", true) TXTErrIgnore = boolFromEnv("GRPC_GO_IGNORE_TXT_ERRORS", true)
// RingHashCap indicates the maximum ring size which defaults to 4096 // RingHashCap indicates the maximum ring size which defaults to 4096
// entries but may be overridden by setting the environment variable // entries but may be overridden by setting the environment variable
// "GRPC_RING_HASH_CAP". This does not override the default bounds // "GRPC_RING_HASH_CAP". This does not override the default bounds
// checking which NACKs configs specifying ring sizes > 8*1024*1024 (~8M). // checking which NACKs configs specifying ring sizes > 8*1024*1024 (~8M).
RingHashCap = uint64FromEnv("GRPC_RING_HASH_CAP", 4096, 1, 8*1024*1024) RingHashCap = uint64FromEnv("GRPC_RING_HASH_CAP", 4096, 1, 8*1024*1024)
// ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS // ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS
// handshakes that can be performed. // handshakes that can be performed.
ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100) ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100)
// EnforceALPNEnabled is set if TLS connections to servers with ALPN disabled // EnforceALPNEnabled is set if TLS connections to servers with ALPN disabled
// should be rejected. The HTTP/2 protocol requires ALPN to be enabled, this // should be rejected. The HTTP/2 protocol requires ALPN to be enabled, this
// option is present for backward compatibility. This option may be overridden // option is present for backward compatibility. This option may be overridden
// by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true" // by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true"
// or "false". // or "false".
EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", true) EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", true)
// XDSFallbackSupport is the env variable that controls whether support for
// xDS fallback is turned on. If this is unset or is false, only the first
// xDS server in the list of server configs will be used.
XDSFallbackSupport = boolFromEnv("GRPC_EXPERIMENTAL_XDS_FALLBACK", true)
// NewPickFirstEnabled is set if the new pickfirst leaf policy is to be used
// instead of the exiting pickfirst implementation. This can be disabled by
// setting the environment variable "GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST"
// to "false".
NewPickFirstEnabled = boolFromEnv("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST", true)
// XDSEndpointHashKeyBackwardCompat controls the parsing of the endpoint hash // XDSEndpointHashKeyBackwardCompat controls the parsing of the endpoint hash
// key from EDS LbEndpoint metadata. Endpoint hash keys can be disabled by // key from EDS LbEndpoint metadata. Endpoint hash keys can be disabled by
@@ -69,6 +69,41 @@ var (
// ALTSHandshakerKeepaliveParams is set if we should add the // ALTSHandshakerKeepaliveParams is set if we should add the
// KeepaliveParams when dial the ALTS handshaker service. // KeepaliveParams when dial the ALTS handshaker service.
ALTSHandshakerKeepaliveParams = boolFromEnv("GRPC_EXPERIMENTAL_ALTS_HANDSHAKER_KEEPALIVE_PARAMS", false) ALTSHandshakerKeepaliveParams = boolFromEnv("GRPC_EXPERIMENTAL_ALTS_HANDSHAKER_KEEPALIVE_PARAMS", false)
// EnableDefaultPortForProxyTarget controls whether the resolver adds a default port 443
// to a target address that lacks one. This flag only has an effect when all of
// the following conditions are met:
// - A connect proxy is being used.
// - Target resolution is disabled.
// - The DNS resolver is being used.
EnableDefaultPortForProxyTarget = boolFromEnv("GRPC_EXPERIMENTAL_ENABLE_DEFAULT_PORT_FOR_PROXY_TARGET", true)
// XDSAuthorityRewrite indicates whether xDS authority rewriting is enabled.
// This feature is defined in gRFC A81 and is enabled by setting the
// environment variable GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE to "true".
XDSAuthorityRewrite = boolFromEnv("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", false)
// PickFirstWeightedShuffling indicates whether weighted endpoint shuffling
// is enabled in the pick_first LB policy, as defined in gRFC A113. This
// feature can be disabled by setting the environment variable
// GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING to "false".
PickFirstWeightedShuffling = boolFromEnv("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true)
// DisableStrictPathChecking indicates whether strict path checking is
// disabled. This feature can be disabled by setting the environment
// variable GRPC_GO_EXPERIMENTAL_DISABLE_STRICT_PATH_CHECKING to "true".
//
// When strict path checking is enabled, gRPC will reject requests with
// paths that do not conform to the gRPC over HTTP/2 specification found at
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md.
//
// When disabled, gRPC will allow paths that do not contain a leading slash.
// Enabling strict path checking is recommended for security reasons, as it
// prevents potential path traversal vulnerabilities.
//
// A future release will remove this environment variable, enabling strict
// path checking behavior unconditionally.
DisableStrictPathChecking = boolFromEnv("GRPC_GO_EXPERIMENTAL_DISABLE_STRICT_PATH_CHECKING", false)
) )
func boolFromEnv(envVar string, def bool) bool { func boolFromEnv(envVar string, def bool) bool {

View File

@@ -68,4 +68,15 @@ var (
// trust. For more details, see: // trust. For more details, see:
// https://github.com/grpc/proposal/blob/master/A87-mtls-spiffe-support.md // https://github.com/grpc/proposal/blob/master/A87-mtls-spiffe-support.md
XDSSPIFFEEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_MTLS_SPIFFE", false) XDSSPIFFEEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_MTLS_SPIFFE", false)
// XDSHTTPConnectEnabled is true if gRPC should parse custom Metadata
// configuring use of an HTTP CONNECT proxy via xDS from cluster resources.
// For more details, see:
// https://github.com/grpc/proposal/blob/master/A86-xds-http-connect.md
XDSHTTPConnectEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_HTTP_CONNECT", false)
// XDSBootstrapCallCredsEnabled controls if call credentials can be used in
// xDS bootstrap configuration via the `call_creds` field. For more details,
// see: https://github.com/grpc/proposal/blob/master/A97-xds-jwt-call-creds.md
XDSBootstrapCallCredsEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_BOOTSTRAP_CALL_CREDS", false)
) )

View File

@@ -25,4 +25,11 @@ var (
// BufferPool is implemented by the grpc package and returns a server // BufferPool is implemented by the grpc package and returns a server
// option to configure a shared buffer pool for a grpc.Server. // option to configure a shared buffer pool for a grpc.Server.
BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption
// SetDefaultBufferPool updates the default buffer pool.
SetDefaultBufferPool any // func(mem.BufferPool)
// AcceptCompressors is implemented by the grpc package and returns
// a call option that restricts the grpc-accept-encoding header for a call.
AcceptCompressors any // func(...string) grpc.CallOption
) )

View File

@@ -80,25 +80,11 @@ func (cs *CallbackSerializer) ScheduleOr(f func(ctx context.Context), onFailure
func (cs *CallbackSerializer) run(ctx context.Context) { func (cs *CallbackSerializer) run(ctx context.Context) {
defer close(cs.done) defer close(cs.done)
// TODO: when Go 1.21 is the oldest supported version, this loop and Close // Close the buffer when the context is canceled
// can be replaced with: // to prevent new callbacks from being added.
// context.AfterFunc(ctx, cs.callbacks.Close)
// context.AfterFunc(ctx, cs.callbacks.Close)
for ctx.Err() == nil {
select {
case <-ctx.Done():
// Do nothing here. Next iteration of the for loop will not happen,
// since ctx.Err() would be non-nil.
case cb := <-cs.callbacks.Get():
cs.callbacks.Load()
cb.(func(context.Context))(ctx)
}
}
// Close the buffer to prevent new callbacks from being added. // Run all callbacks.
cs.callbacks.Close()
// Run all pending callbacks.
for cb := range cs.callbacks.Get() { for cb := range cs.callbacks.Get() {
cs.callbacks.Load() cs.callbacks.Load()
cb.(func(context.Context))(ctx) cb.(func(context.Context))(ctx)

View File

@@ -21,7 +21,6 @@
package idle package idle
import ( import (
"fmt"
"math" "math"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -33,15 +32,15 @@ var timeAfterFunc = func(d time.Duration, f func()) *time.Timer {
return time.AfterFunc(d, f) return time.AfterFunc(d, f)
} }
// Enforcer is the functionality provided by grpc.ClientConn to enter // ClientConn is the functionality provided by grpc.ClientConn to enter and exit
// and exit from idle mode. // from idle mode.
type Enforcer interface { type ClientConn interface {
ExitIdleMode() error ExitIdleMode()
EnterIdleMode() EnterIdleMode()
} }
// Manager implements idleness detection and calls the configured Enforcer to // Manager implements idleness detection and calls the ClientConn to enter/exit
// enter/exit idle mode when appropriate. Must be created by NewManager. // idle mode when appropriate. Must be created by NewManager.
type Manager struct { type Manager struct {
// State accessed atomically. // State accessed atomically.
lastCallEndTime int64 // Unix timestamp in nanos; time when the most recent RPC completed. lastCallEndTime int64 // Unix timestamp in nanos; time when the most recent RPC completed.
@@ -51,7 +50,7 @@ type Manager struct {
// Can be accessed without atomics or mutex since these are set at creation // Can be accessed without atomics or mutex since these are set at creation
// time and read-only after that. // time and read-only after that.
enforcer Enforcer // Functionality provided by grpc.ClientConn. cc ClientConn // Functionality provided by grpc.ClientConn.
timeout time.Duration timeout time.Duration
// idleMu is used to guarantee mutual exclusion in two scenarios: // idleMu is used to guarantee mutual exclusion in two scenarios:
@@ -72,9 +71,9 @@ type Manager struct {
// NewManager creates a new idleness manager implementation for the // NewManager creates a new idleness manager implementation for the
// given idle timeout. It begins in idle mode. // given idle timeout. It begins in idle mode.
func NewManager(enforcer Enforcer, timeout time.Duration) *Manager { func NewManager(cc ClientConn, timeout time.Duration) *Manager {
return &Manager{ return &Manager{
enforcer: enforcer, cc: cc,
timeout: timeout, timeout: timeout,
actuallyIdle: true, actuallyIdle: true,
activeCallsCount: -math.MaxInt32, activeCallsCount: -math.MaxInt32,
@@ -127,7 +126,7 @@ func (m *Manager) handleIdleTimeout() {
// Now that we've checked that there has been no activity, attempt to enter // Now that we've checked that there has been no activity, attempt to enter
// idle mode, which is very likely to succeed. // idle mode, which is very likely to succeed.
if m.tryEnterIdleMode() { if m.tryEnterIdleMode(true) {
// Successfully entered idle mode. No timer needed until we exit idle. // Successfully entered idle mode. No timer needed until we exit idle.
return return
} }
@@ -142,10 +141,13 @@ func (m *Manager) handleIdleTimeout() {
// that, it performs a last minute check to ensure that no new RPC has come in, // that, it performs a last minute check to ensure that no new RPC has come in,
// making the channel active. // making the channel active.
// //
// checkActivity controls if a check for RPC activity, since the last time the
// idle_timeout fired, is made.
// Return value indicates whether or not the channel moved to idle mode. // Return value indicates whether or not the channel moved to idle mode.
// //
// Holds idleMu which ensures mutual exclusion with exitIdleMode. // Holds idleMu which ensures mutual exclusion with exitIdleMode.
func (m *Manager) tryEnterIdleMode() bool { func (m *Manager) tryEnterIdleMode(checkActivity bool) bool {
// Setting the activeCallsCount to -math.MaxInt32 indicates to OnCallBegin() // Setting the activeCallsCount to -math.MaxInt32 indicates to OnCallBegin()
// that the channel is either in idle mode or is trying to get there. // that the channel is either in idle mode or is trying to get there.
if !atomic.CompareAndSwapInt32(&m.activeCallsCount, 0, -math.MaxInt32) { if !atomic.CompareAndSwapInt32(&m.activeCallsCount, 0, -math.MaxInt32) {
@@ -166,7 +168,7 @@ func (m *Manager) tryEnterIdleMode() bool {
atomic.AddInt32(&m.activeCallsCount, math.MaxInt32) atomic.AddInt32(&m.activeCallsCount, math.MaxInt32)
return false return false
} }
if atomic.LoadInt32(&m.activeSinceLastTimerCheck) == 1 { if checkActivity && atomic.LoadInt32(&m.activeSinceLastTimerCheck) == 1 {
// A very short RPC could have come in (and also finished) after we // A very short RPC could have come in (and also finished) after we
// checked for calls count and activity in handleIdleTimeout(), but // checked for calls count and activity in handleIdleTimeout(), but
// before the CAS operation. So, we need to check for activity again. // before the CAS operation. So, we need to check for activity again.
@@ -177,44 +179,37 @@ func (m *Manager) tryEnterIdleMode() bool {
// No new RPCs have come in since we set the active calls count value to // No new RPCs have come in since we set the active calls count value to
// -math.MaxInt32. And since we have the lock, it is safe to enter idle mode // -math.MaxInt32. And since we have the lock, it is safe to enter idle mode
// unconditionally now. // unconditionally now.
m.enforcer.EnterIdleMode() m.cc.EnterIdleMode()
m.actuallyIdle = true m.actuallyIdle = true
return true return true
} }
// EnterIdleModeForTesting instructs the channel to enter idle mode. // EnterIdleModeForTesting instructs the channel to enter idle mode.
func (m *Manager) EnterIdleModeForTesting() { func (m *Manager) EnterIdleModeForTesting() {
m.tryEnterIdleMode() m.tryEnterIdleMode(false)
} }
// OnCallBegin is invoked at the start of every RPC. // OnCallBegin is invoked at the start of every RPC.
func (m *Manager) OnCallBegin() error { func (m *Manager) OnCallBegin() {
if m.isClosed() { if m.isClosed() {
return nil return
} }
if atomic.AddInt32(&m.activeCallsCount, 1) > 0 { if atomic.AddInt32(&m.activeCallsCount, 1) > 0 {
// Channel is not idle now. Set the activity bit and allow the call. // Channel is not idle now. Set the activity bit and allow the call.
atomic.StoreInt32(&m.activeSinceLastTimerCheck, 1) atomic.StoreInt32(&m.activeSinceLastTimerCheck, 1)
return nil return
} }
// Channel is either in idle mode or is in the process of moving to idle // Channel is either in idle mode or is in the process of moving to idle
// mode. Attempt to exit idle mode to allow this RPC. // mode. Attempt to exit idle mode to allow this RPC.
if err := m.ExitIdleMode(); err != nil { m.ExitIdleMode()
// Undo the increment to calls count, and return an error causing the
// RPC to fail.
atomic.AddInt32(&m.activeCallsCount, -1)
return err
}
atomic.StoreInt32(&m.activeSinceLastTimerCheck, 1) atomic.StoreInt32(&m.activeSinceLastTimerCheck, 1)
return nil
} }
// ExitIdleMode instructs m to call the enforcer's ExitIdleMode and update m's // ExitIdleMode instructs m to call the ClientConn's ExitIdleMode and update its
// internal state. // internal state.
func (m *Manager) ExitIdleMode() error { func (m *Manager) ExitIdleMode() {
// Holds idleMu which ensures mutual exclusion with tryEnterIdleMode. // Holds idleMu which ensures mutual exclusion with tryEnterIdleMode.
m.idleMu.Lock() m.idleMu.Lock()
defer m.idleMu.Unlock() defer m.idleMu.Unlock()
@@ -231,12 +226,10 @@ func (m *Manager) ExitIdleMode() error {
// m.ExitIdleMode. // m.ExitIdleMode.
// //
// In any case, there is nothing to do here. // In any case, there is nothing to do here.
return nil return
} }
if err := m.enforcer.ExitIdleMode(); err != nil { m.cc.ExitIdleMode()
return fmt.Errorf("failed to exit idle mode: %w", err)
}
// Undo the idle entry process. This also respects any new RPC attempts. // Undo the idle entry process. This also respects any new RPC attempts.
atomic.AddInt32(&m.activeCallsCount, math.MaxInt32) atomic.AddInt32(&m.activeCallsCount, math.MaxInt32)
@@ -244,7 +237,23 @@ func (m *Manager) ExitIdleMode() error {
// Start a new timer to fire after the configured idle timeout. // Start a new timer to fire after the configured idle timeout.
m.resetIdleTimerLocked(m.timeout) m.resetIdleTimerLocked(m.timeout)
return nil }
// UnsafeSetNotIdle instructs the Manager to update its internal state to
// reflect the reality that the channel is no longer in IDLE mode.
//
// N.B. This method is intended only for internal use by the gRPC client
// when it exits IDLE mode **manually** from `Dial`. The callsite must ensure:
// - The channel was **actually in IDLE mode** immediately prior to the call.
// - There is **no concurrent activity** that could cause the channel to exit
// IDLE mode *naturally* at the same time.
func (m *Manager) UnsafeSetNotIdle() {
m.idleMu.Lock()
defer m.idleMu.Unlock()
atomic.AddInt32(&m.activeCallsCount, math.MaxInt32)
m.actuallyIdle = false
m.resetIdleTimerLocked(m.timeout)
} }
// OnCallEnd is invoked at the end of every RPC. // OnCallEnd is invoked at the end of every RPC.

View File

@@ -182,35 +182,6 @@ var (
// other features, including the CSDS service. // other features, including the CSDS service.
NewXDSResolverWithClientForTesting any // func(xdsclient.XDSClient) (resolver.Builder, error) NewXDSResolverWithClientForTesting any // func(xdsclient.XDSClient) (resolver.Builder, error)
// RegisterRLSClusterSpecifierPluginForTesting registers the RLS Cluster
// Specifier Plugin for testing purposes, regardless of the XDSRLS environment
// variable.
//
// TODO: Remove this function once the RLS env var is removed.
RegisterRLSClusterSpecifierPluginForTesting func()
// UnregisterRLSClusterSpecifierPluginForTesting unregisters the RLS Cluster
// Specifier Plugin for testing purposes. This is needed because there is no way
// to unregister the RLS Cluster Specifier Plugin after registering it solely
// for testing purposes using RegisterRLSClusterSpecifierPluginForTesting().
//
// TODO: Remove this function once the RLS env var is removed.
UnregisterRLSClusterSpecifierPluginForTesting func()
// RegisterRBACHTTPFilterForTesting registers the RBAC HTTP Filter for testing
// purposes, regardless of the RBAC environment variable.
//
// TODO: Remove this function once the RBAC env var is removed.
RegisterRBACHTTPFilterForTesting func()
// UnregisterRBACHTTPFilterForTesting unregisters the RBAC HTTP Filter for
// testing purposes. This is needed because there is no way to unregister the
// HTTP Filter after registering it solely for testing purposes using
// RegisterRBACHTTPFilterForTesting().
//
// TODO: Remove this function once the RBAC env var is removed.
UnregisterRBACHTTPFilterForTesting func()
// ORCAAllowAnyMinReportingInterval is for examples/orca use ONLY. // ORCAAllowAnyMinReportingInterval is for examples/orca use ONLY.
ORCAAllowAnyMinReportingInterval any // func(so *orca.ServiceOptions) ORCAAllowAnyMinReportingInterval any // func(so *orca.ServiceOptions)
@@ -240,22 +211,11 @@ var (
// default resolver scheme. // default resolver scheme.
UserSetDefaultScheme = false UserSetDefaultScheme = false
// ConnectedAddress returns the connected address for a SubConnState. The
// address is only valid if the state is READY.
ConnectedAddress any // func (scs SubConnState) resolver.Address
// SetConnectedAddress sets the connected address for a SubConnState.
SetConnectedAddress any // func(scs *SubConnState, addr resolver.Address)
// SnapshotMetricRegistryForTesting snapshots the global data of the metric // SnapshotMetricRegistryForTesting snapshots the global data of the metric
// registry. Returns a cleanup function that sets the metric registry to its // registry. Returns a cleanup function that sets the metric registry to its
// original state. Only called in testing functions. // original state. Only called in testing functions.
SnapshotMetricRegistryForTesting func() func() SnapshotMetricRegistryForTesting func() func()
// SetDefaultBufferPoolForTesting updates the default buffer pool, for
// testing purposes.
SetDefaultBufferPoolForTesting any // func(mem.BufferPool)
// SetBufferPoolingThresholdForTesting updates the buffer pooling threshold, for // SetBufferPoolingThresholdForTesting updates the buffer pooling threshold, for
// testing purposes. // testing purposes.
SetBufferPoolingThresholdForTesting any // func(int) SetBufferPoolingThresholdForTesting any // func(int)
@@ -273,6 +233,18 @@ var (
// When set, the function will be called before the stream enters // When set, the function will be called before the stream enters
// the blocking state. // the blocking state.
NewStreamWaitingForResolver = func() {} NewStreamWaitingForResolver = func() {}
// AddressToTelemetryLabels is an xDS-provided function to extract telemetry
// labels from a resolver.Address. Callers must assert its type before calling.
AddressToTelemetryLabels any // func(addr resolver.Address) map[string]string
// AsyncReporterCleanupDelegate is initialized to a pass-through function by
// default (production behavior), allowing tests to swap it with an
// implementation which tracks registration of async reporter and its
// corresponding cleanup.
AsyncReporterCleanupDelegate = func(cleanup func()) func() {
return cleanup
}
) )
// HealthChecker defines the signature of the client-side LB channel health // HealthChecker defines the signature of the client-side LB channel health
@@ -320,3 +292,9 @@ type EnforceClientConnEmbedding interface {
type Timer interface { type Timer interface {
Stop() bool Stop() bool
} }
// EnforceMetricsRecorderEmbedding is used to enforce proper MetricsRecorder
// implementation embedding.
type EnforceMetricsRecorderEmbedding interface {
enforceMetricsRecorderEmbedding()
}

View File

@@ -22,11 +22,13 @@ package delegatingresolver
import ( import (
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/url" "net/url"
"sync" "sync"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/proxyattributes" "google.golang.org/grpc/internal/proxyattributes"
"google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/internal/transport/networktype" "google.golang.org/grpc/internal/transport/networktype"
@@ -40,6 +42,8 @@ var (
HTTPSProxyFromEnvironment = http.ProxyFromEnvironment HTTPSProxyFromEnvironment = http.ProxyFromEnvironment
) )
const defaultPort = "443"
// delegatingResolver manages both target URI and proxy address resolution by // delegatingResolver manages both target URI and proxy address resolution by
// delegating these tasks to separate child resolvers. Essentially, it acts as // delegating these tasks to separate child resolvers. Essentially, it acts as
// an intermediary between the gRPC ClientConn and the child resolvers. // an intermediary between the gRPC ClientConn and the child resolvers.
@@ -107,10 +111,18 @@ func New(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOpti
targetResolver: nopResolver{}, targetResolver: nopResolver{},
} }
addr := target.Endpoint()
var err error var err error
r.proxyURL, err = proxyURLForTarget(target.Endpoint()) if target.URL.Scheme == "dns" && !targetResolutionEnabled && envconfig.EnableDefaultPortForProxyTarget {
addr, err = parseTarget(addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("delegating_resolver: failed to determine proxy URL for target %s: %v", target, err) return nil, fmt.Errorf("delegating_resolver: invalid target address %q: %v", target.Endpoint(), err)
}
}
r.proxyURL, err = proxyURLForTarget(addr)
if err != nil {
return nil, fmt.Errorf("delegating_resolver: failed to determine proxy URL for target %q: %v", target, err)
} }
// proxy is not configured or proxy address excluded using `NO_PROXY` env // proxy is not configured or proxy address excluded using `NO_PROXY` env
@@ -132,8 +144,8 @@ func New(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOpti
// bypass the target resolver and store the unresolved target address. // bypass the target resolver and store the unresolved target address.
if target.URL.Scheme == "dns" && !targetResolutionEnabled { if target.URL.Scheme == "dns" && !targetResolutionEnabled {
r.targetResolverState = &resolver.State{ r.targetResolverState = &resolver.State{
Addresses: []resolver.Address{{Addr: target.Endpoint()}}, Addresses: []resolver.Address{{Addr: addr}},
Endpoints: []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: target.Endpoint()}}}}, Endpoints: []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: addr}}}},
} }
r.updateTargetResolverState(*r.targetResolverState) r.updateTargetResolverState(*r.targetResolverState)
return r, nil return r, nil
@@ -202,6 +214,44 @@ func needsProxyResolver(state *resolver.State) bool {
return false return false
} }
// parseTarget takes a target string and ensures it is a valid "host:port" target.
//
// It does the following:
// 1. If the target already has a port (e.g., "host:port", "[ipv6]:port"),
// it is returned as is.
// 2. If the host part is empty (e.g., ":80"), it defaults to "localhost",
// returning "localhost:80".
// 3. If the target is missing a port (e.g., "host", "ipv6"), the defaultPort
// is added.
//
// An error is returned for empty targets or targets with a trailing colon
// but no port (e.g., "host:").
func parseTarget(target string) (string, error) {
if target == "" {
return "", fmt.Errorf("missing address")
}
host, port, err := net.SplitHostPort(target)
if err != nil {
// If SplitHostPort fails, it's likely because the port is missing.
// We append the default port and return the result.
return net.JoinHostPort(target, defaultPort), nil
}
// If SplitHostPort succeeds, we check for edge cases.
if port == "" {
// A success with an empty port means the target had a trailing colon,
// e.g., "host:", which is an error.
return "", fmt.Errorf("missing port after port-separator colon")
}
if host == "" {
// A success with an empty host means the target was like ":80".
// We default the host to "localhost".
host = "localhost"
}
return net.JoinHostPort(host, port), nil
}
func skipProxy(address resolver.Address) bool { func skipProxy(address resolver.Address) bool {
// Avoid proxy when network is not tcp. // Avoid proxy when network is not tcp.
networkType, ok := networktype.Get(address) networkType, ok := networktype.Get(address)

View File

@@ -125,7 +125,10 @@ func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts
// IP address. // IP address.
if ipAddr, err := formatIP(host); err == nil { if ipAddr, err := formatIP(host); err == nil {
addr := []resolver.Address{{Addr: ipAddr + ":" + port}} addr := []resolver.Address{{Addr: ipAddr + ":" + port}}
cc.UpdateState(resolver.State{Addresses: addr}) cc.UpdateState(resolver.State{
Addresses: addr,
Endpoints: []resolver.Endpoint{{Addresses: addr}},
})
return deadResolver{}, nil return deadResolver{}, nil
} }
@@ -138,7 +141,7 @@ func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts
cancel: cancel, cancel: cancel,
cc: cc, cc: cc,
rn: make(chan struct{}, 1), rn: make(chan struct{}, 1),
disableServiceConfig: opts.DisableServiceConfig, enableServiceConfig: envconfig.EnableTXTServiceConfig && !opts.DisableServiceConfig,
} }
d.resolver, err = internal.NewNetResolver(target.URL.Host) d.resolver, err = internal.NewNetResolver(target.URL.Host)
@@ -182,7 +185,7 @@ type dnsResolver struct {
// function pointers) inside watcher() goroutine has data race with // function pointers) inside watcher() goroutine has data race with
// replaceNetFunc (WRITE the lookup function pointers). // replaceNetFunc (WRITE the lookup function pointers).
wg sync.WaitGroup wg sync.WaitGroup
disableServiceConfig bool enableServiceConfig bool
} }
// ResolveNow invoke an immediate resolution of the target that this // ResolveNow invoke an immediate resolution of the target that this
@@ -342,11 +345,19 @@ func (d *dnsResolver) lookup() (*resolver.State, error) {
return nil, hostErr return nil, hostErr
} }
state := resolver.State{Addresses: addrs} eps := make([]resolver.Endpoint, 0, len(addrs))
for _, addr := range addrs {
eps = append(eps, resolver.Endpoint{Addresses: []resolver.Address{addr}})
}
state := resolver.State{
Addresses: addrs,
Endpoints: eps,
}
if len(srv) > 0 { if len(srv) > 0 {
state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv}) state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv})
} }
if !d.disableServiceConfig { if d.enableServiceConfig {
state.ServiceConfig = d.lookupTXT(ctx) state.ServiceConfig = d.lookupTXT(ctx)
} }
return &state, nil return &state, nil

View File

@@ -20,6 +20,7 @@ import (
"fmt" "fmt"
estats "google.golang.org/grpc/experimental/stats" estats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
) )
@@ -28,6 +29,7 @@ import (
// It eats any record calls where the label values provided do not match the // It eats any record calls where the label values provided do not match the
// number of label keys. // number of label keys.
type MetricsRecorderList struct { type MetricsRecorderList struct {
internal.EnforceMetricsRecorderEmbedding
// metricsRecorders are the metrics recorders this list will forward to. // metricsRecorders are the metrics recorders this list will forward to.
metricsRecorders []estats.MetricsRecorder metricsRecorders []estats.MetricsRecorder
} }
@@ -64,6 +66,16 @@ func (l *MetricsRecorderList) RecordInt64Count(handle *estats.Int64CountHandle,
} }
} }
// RecordInt64UpDownCount records the measurement alongside labels on the int
// count associated with the provided handle.
func (l *MetricsRecorderList) RecordInt64UpDownCount(handle *estats.Int64UpDownCountHandle, incr int64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordInt64UpDownCount(handle, incr, labels...)
}
}
// RecordFloat64Count records the measurement alongside labels on the float // RecordFloat64Count records the measurement alongside labels on the float
// count associated with the provided handle. // count associated with the provided handle.
func (l *MetricsRecorderList) RecordFloat64Count(handle *estats.Float64CountHandle, incr float64, labels ...string) { func (l *MetricsRecorderList) RecordFloat64Count(handle *estats.Float64CountHandle, incr float64, labels ...string) {
@@ -103,3 +115,61 @@ func (l *MetricsRecorderList) RecordInt64Gauge(handle *estats.Int64GaugeHandle,
metricRecorder.RecordInt64Gauge(handle, incr, labels...) metricRecorder.RecordInt64Gauge(handle, incr, labels...)
} }
} }
// RegisterAsyncReporter forwards the registration to all underlying metrics
// recorders.
//
// It returns a cleanup function that, when called, invokes the cleanup function
// returned by each underlying recorder, ensuring the reporter is unregistered
// from all of them.
func (l *MetricsRecorderList) RegisterAsyncReporter(reporter estats.AsyncMetricReporter, metrics ...estats.AsyncMetric) func() {
descriptorsMap := make(map[*estats.MetricDescriptor]bool, len(metrics))
for _, m := range metrics {
descriptorsMap[m.Descriptor()] = true
}
unregisterFns := make([]func(), 0, len(l.metricsRecorders))
for _, mr := range l.metricsRecorders {
// Wrap the AsyncMetricsRecorder to intercept calls to RecordInt64Gauge
// and validate the labels.
wrappedCallback := func(recorder estats.AsyncMetricsRecorder) error {
wrappedRecorder := &asyncRecorderWrapper{
delegate: recorder,
descriptors: descriptorsMap,
}
return reporter.Report(wrappedRecorder)
}
unregisterFns = append(unregisterFns, mr.RegisterAsyncReporter(estats.AsyncMetricReporterFunc(wrappedCallback), metrics...))
}
// Wrap the cleanup function using the internal delegate.
// In production, this returns realCleanup as-is.
// In tests, the leak checker can swap this to track the registration lifetime.
return internal.AsyncReporterCleanupDelegate(defaultCleanUp(unregisterFns))
}
func defaultCleanUp(unregisterFns []func()) func() {
return func() {
for _, unregister := range unregisterFns {
unregister()
}
}
}
type asyncRecorderWrapper struct {
delegate estats.AsyncMetricsRecorder
descriptors map[*estats.MetricDescriptor]bool
}
// RecordIntAsync64Gauge records the measurement alongside labels on the int
// gauge associated with the provided handle.
func (w *asyncRecorderWrapper) RecordInt64AsyncGauge(handle *estats.Int64AsyncGaugeHandle, value int64, labels ...string) {
// Ensure only metrics for descriptors passed during callback registration
// are emitted.
d := handle.Descriptor()
if _, ok := w.descriptors[d]; !ok {
return
}
// Validate labels and delegate.
verifyLabels(d, labels...)
w.delegate.RecordInt64AsyncGauge(handle, value, labels...)
}

70
vendor/google.golang.org/grpc/internal/stats/stats.go generated vendored Normal file
View File

@@ -0,0 +1,70 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package stats
import (
"context"
"google.golang.org/grpc/stats"
)
type combinedHandler struct {
handlers []stats.Handler
}
// NewCombinedHandler combines multiple stats.Handlers into a single handler.
//
// It returns nil if no handlers are provided. If only one handler is
// provided, it is returned directly without wrapping.
func NewCombinedHandler(handlers ...stats.Handler) stats.Handler {
switch len(handlers) {
case 0:
return nil
case 1:
return handlers[0]
default:
return &combinedHandler{handlers: handlers}
}
}
func (ch *combinedHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
for _, h := range ch.handlers {
ctx = h.TagRPC(ctx, info)
}
return ctx
}
func (ch *combinedHandler) HandleRPC(ctx context.Context, stats stats.RPCStats) {
for _, h := range ch.handlers {
h.HandleRPC(ctx, stats)
}
}
func (ch *combinedHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
for _, h := range ch.handlers {
ctx = h.TagConn(ctx, info)
}
return ctx
}
func (ch *combinedHandler) HandleConn(ctx context.Context, stats stats.ConnStats) {
for _, h := range ch.handlers {
h.HandleConn(ctx, stats)
}
}

View File

@@ -24,30 +24,34 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
"google.golang.org/grpc/mem" "google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
// ClientStream implements streaming functionality for a gRPC client. // ClientStream implements streaming functionality for a gRPC client.
type ClientStream struct { type ClientStream struct {
*Stream // Embed for common stream functionality. Stream // Embed for common stream functionality.
ct *http2Client ct *http2Client
done chan struct{} // closed at the end of stream to unblock writers. done chan struct{} // closed at the end of stream to unblock writers.
doneFunc func() // invoked at the end of stream. doneFunc func() // invoked at the end of stream.
headerChan chan struct{} // closed to indicate the end of header metadata. headerChan chan struct{} // closed to indicate the end of header metadata.
headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times. header metadata.MD // the received header metadata
status *status.Status // the status error received from the server
// Non-pointer fields are at the end to optimize GC allocations.
// headerValid indicates whether a valid header was received. Only // headerValid indicates whether a valid header was received. Only
// meaningful after headerChan is closed (always call waitOnHeader() before // meaningful after headerChan is closed (always call waitOnHeader() before
// reading its value). // reading its value).
headerValid bool headerValid bool
header metadata.MD // the received header metadata
noHeaders bool // set if the client never received headers (set only after the stream is done). noHeaders bool // set if the client never received headers (set only after the stream is done).
headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
bytesReceived atomic.Bool // indicates whether any bytes have been received on this stream bytesReceived atomic.Bool // indicates whether any bytes have been received on this stream
unprocessed atomic.Bool // set if the server sends a refused stream or GOAWAY including this stream unprocessed atomic.Bool // set if the server sends a refused stream or GOAWAY including this stream
statsHandler stats.Handler // nil for internal streams (e.g., health check, ORCA) where telemetry is not supported.
status *status.Status // the status error received from the server
} }
// Read reads an n byte message from the input stream. // Read reads an n byte message from the input stream.
@@ -142,3 +146,11 @@ func (s *ClientStream) TrailersOnly() bool {
func (s *ClientStream) Status() *status.Status { func (s *ClientStream) Status() *status.Status {
return s.status return s.status
} }
func (s *ClientStream) requestRead(n int) {
s.ct.adjustWindow(s, uint32(n))
}
func (s *ClientStream) updateWindow(n int) {
s.ct.updateWindow(s, uint32(n))
}

View File

@@ -24,16 +24,13 @@ import (
"fmt" "fmt"
"net" "net"
"runtime" "runtime"
"strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
"google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/mem" "google.golang.org/grpc/mem"
"google.golang.org/grpc/status"
) )
var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) { var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
@@ -147,11 +144,9 @@ type cleanupStream struct {
func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM
type earlyAbortStream struct { type earlyAbortStream struct {
httpStatus uint32
streamID uint32 streamID uint32
contentSubtype string
status *status.Status
rst bool rst bool
hf []hpack.HeaderField // Pre-built header fields
} }
func (*earlyAbortStream) isTransportResponseFrame() bool { return false } func (*earlyAbortStream) isTransportResponseFrame() bool { return false }
@@ -496,6 +491,16 @@ const (
serverSide serverSide
) )
// maxWriteBufSize is the maximum length (number of elements) the cached
// writeBuf can grow to. The length depends on the number of buffers
// contained within the BufferSlice produced by the codec, which is
// generally small.
//
// If a writeBuf larger than this limit is required, it will be allocated
// and freed after use, rather than being cached. This avoids holding
// on to large amounts of memory.
const maxWriteBufSize = 64
// Loopy receives frames from the control buffer. // Loopy receives frames from the control buffer.
// Each frame is handled individually; most of the work done by loopy goes // Each frame is handled individually; most of the work done by loopy goes
// into handling data frames. Loopy maintains a queue of active streams, and each // into handling data frames. Loopy maintains a queue of active streams, and each
@@ -530,6 +535,8 @@ type loopyWriter struct {
// Side-specific handlers // Side-specific handlers
ssGoAwayHandler func(*goAway) (bool, error) ssGoAwayHandler func(*goAway) (bool, error)
writeBuf [][]byte // cached slice to avoid heap allocations for calls to mem.Reader.Peek.
} }
func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger, goAwayHandler func(*goAway) (bool, error), bufferPool mem.BufferPool) *loopyWriter { func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger, goAwayHandler func(*goAway) (bool, error), bufferPool mem.BufferPool) *loopyWriter {
@@ -669,7 +676,6 @@ func (l *loopyWriter) registerStreamHandler(h *registerStream) {
state: empty, state: empty,
itl: &itemList{}, itl: &itemList{},
wq: h.wq, wq: h.wq,
reader: mem.BufferSlice{}.Reader(),
} }
l.estdStreams[h.streamID] = str l.estdStreams[h.streamID] = str
} }
@@ -705,7 +711,6 @@ func (l *loopyWriter) headerHandler(h *headerFrame) error {
state: empty, state: empty,
itl: &itemList{}, itl: &itemList{},
wq: h.wq, wq: h.wq,
reader: mem.BufferSlice{}.Reader(),
} }
return l.originateStream(str, h) return l.originateStream(str, h)
} }
@@ -833,18 +838,7 @@ func (l *loopyWriter) earlyAbortStreamHandler(eas *earlyAbortStream) error {
if l.side == clientSide { if l.side == clientSide {
return errors.New("earlyAbortStream not handled on client") return errors.New("earlyAbortStream not handled on client")
} }
// In case the caller forgets to set the http status, default to 200. if err := l.writeHeader(eas.streamID, true, eas.hf, nil); err != nil {
if eas.httpStatus == 0 {
eas.httpStatus = 200
}
headerFields := []hpack.HeaderField{
{Name: ":status", Value: strconv.Itoa(int(eas.httpStatus))},
{Name: "content-type", Value: grpcutil.ContentType(eas.contentSubtype)},
{Name: "grpc-status", Value: strconv.Itoa(int(eas.status.Code()))},
{Name: "grpc-message", Value: encodeGrpcMessage(eas.status.Message())},
}
if err := l.writeHeader(eas.streamID, true, headerFields, nil); err != nil {
return err return err
} }
if eas.rst { if eas.rst {
@@ -948,11 +942,11 @@ func (l *loopyWriter) processData() (bool, error) {
if str == nil { if str == nil {
return true, nil return true, nil
} }
reader := str.reader reader := &str.reader
dataItem := str.itl.peek().(*dataFrame) // Peek at the first data item this stream. dataItem := str.itl.peek().(*dataFrame) // Peek at the first data item this stream.
if !dataItem.processing { if !dataItem.processing {
dataItem.processing = true dataItem.processing = true
str.reader.Reset(dataItem.data) reader.Reset(dataItem.data)
dataItem.data.Free() dataItem.data.Free()
} }
// A data item is represented by a dataFrame, since it later translates into // A data item is represented by a dataFrame, since it later translates into
@@ -964,11 +958,11 @@ func (l *loopyWriter) processData() (bool, error) {
if len(dataItem.h) == 0 && reader.Remaining() == 0 { // Empty data frame if len(dataItem.h) == 0 && reader.Remaining() == 0 { // Empty data frame
// Client sends out empty data frame with endStream = true // Client sends out empty data frame with endStream = true
if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil { if err := l.framer.writeData(dataItem.streamID, dataItem.endStream, nil); err != nil {
return false, err return false, err
} }
str.itl.dequeue() // remove the empty data item from stream str.itl.dequeue() // remove the empty data item from stream
_ = reader.Close() reader.Close()
if str.itl.isEmpty() { if str.itl.isEmpty() {
str.state = empty str.state = empty
} else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers. } else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers.
@@ -1001,25 +995,20 @@ func (l *loopyWriter) processData() (bool, error) {
remainingBytes := len(dataItem.h) + reader.Remaining() - hSize - dSize remainingBytes := len(dataItem.h) + reader.Remaining() - hSize - dSize
size := hSize + dSize size := hSize + dSize
var buf *[]byte l.writeBuf = l.writeBuf[:0]
if hSize > 0 {
if hSize != 0 && dSize == 0 { l.writeBuf = append(l.writeBuf, dataItem.h[:hSize])
buf = &dataItem.h }
} else { if dSize > 0 {
// Note: this is only necessary because the http2.Framer does not support var err error
// partially writing a frame, so the sequence must be materialized into a buffer. l.writeBuf, err = reader.Peek(dSize, l.writeBuf)
// TODO: Revisit once https://github.com/golang/go/issues/66655 is addressed. if err != nil {
pool := l.bufferPool // This must never happen since the reader must have at least dSize
if pool == nil { // bytes.
// Note that this is only supposed to be nil in tests. Otherwise, stream is // Log an error to fail tests.
// always initialized with a BufferPool. l.logger.Errorf("unexpected error while reading Data frame payload: %v", err)
pool = mem.DefaultBufferPool() return false, err
} }
buf = pool.Get(size)
defer pool.Put(buf)
copy((*buf)[:hSize], dataItem.h)
_, _ = reader.Read((*buf)[hSize:])
} }
// Now that outgoing flow controls are checked we can replenish str's write quota // Now that outgoing flow controls are checked we can replenish str's write quota
@@ -1032,7 +1021,14 @@ func (l *loopyWriter) processData() (bool, error) {
if dataItem.onEachWrite != nil { if dataItem.onEachWrite != nil {
dataItem.onEachWrite() dataItem.onEachWrite()
} }
if err := l.framer.fr.WriteData(dataItem.streamID, endStream, (*buf)[:size]); err != nil { err := l.framer.writeData(dataItem.streamID, endStream, l.writeBuf)
reader.Discard(dSize)
if cap(l.writeBuf) > maxWriteBufSize {
l.writeBuf = nil
} else {
clear(l.writeBuf)
}
if err != nil {
return false, err return false, err
} }
str.bytesOutStanding += size str.bytesOutStanding += size
@@ -1040,7 +1036,7 @@ func (l *loopyWriter) processData() (bool, error) {
dataItem.h = dataItem.h[hSize:] dataItem.h = dataItem.h[hSize:]
if remainingBytes == 0 { // All the data from that message was written out. if remainingBytes == 0 { // All the data from that message was written out.
_ = reader.Close() reader.Close()
str.itl.dequeue() str.itl.dequeue()
} }
if str.itl.isEmpty() { if str.itl.isEmpty() {

View File

@@ -28,7 +28,7 @@ import (
// writeQuota is a soft limit on the amount of data a stream can // writeQuota is a soft limit on the amount of data a stream can
// schedule before some of it is written out. // schedule before some of it is written out.
type writeQuota struct { type writeQuota struct {
quota int32 _ noCopy
// get waits on read from when quota goes less than or equal to zero. // get waits on read from when quota goes less than or equal to zero.
// replenish writes on it when quota goes positive again. // replenish writes on it when quota goes positive again.
ch chan struct{} ch chan struct{}
@@ -38,16 +38,17 @@ type writeQuota struct {
// It is implemented as a field so that it can be updated // It is implemented as a field so that it can be updated
// by tests. // by tests.
replenish func(n int) replenish func(n int)
quota int32
} }
func newWriteQuota(sz int32, done <-chan struct{}) *writeQuota { // init allows a writeQuota to be initialized in-place, which is useful for
w := &writeQuota{ // resetting a buffer or for avoiding a heap allocation when the buffer is
quota: sz, // embedded in another struct.
ch: make(chan struct{}, 1), func (w *writeQuota) init(sz int32, done <-chan struct{}) {
done: done, w.quota = sz
} w.ch = make(chan struct{}, 1)
w.done = done
w.replenish = w.realReplenish w.replenish = w.realReplenish
return w
} }
func (w *writeQuota) get(sz int32) error { func (w *writeQuota) get(sz int32) error {
@@ -67,9 +68,9 @@ func (w *writeQuota) get(sz int32) error {
func (w *writeQuota) realReplenish(n int) { func (w *writeQuota) realReplenish(n int) {
sz := int32(n) sz := int32(n)
a := atomic.AddInt32(&w.quota, sz) newQuota := atomic.AddInt32(&w.quota, sz)
b := a - sz previousQuota := newQuota - sz
if b <= 0 && a > 0 { if previousQuota <= 0 && newQuota > 0 {
select { select {
case w.ch <- struct{}{}: case w.ch <- struct{}{}:
default: default:

View File

@@ -50,7 +50,7 @@ import (
// NewServerHandlerTransport returns a ServerTransport handling gRPC from // NewServerHandlerTransport returns a ServerTransport handling gRPC from
// inside an http.Handler, or writes an HTTP error to w and returns an error. // inside an http.Handler, or writes an HTTP error to w and returns an error.
// It requires that the http Server supports HTTP/2. // It requires that the http Server supports HTTP/2.
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler, bufferPool mem.BufferPool) (ServerTransport, error) { func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler, bufferPool mem.BufferPool) (ServerTransport, error) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
w.Header().Set("Allow", http.MethodPost) w.Header().Set("Allow", http.MethodPost)
msg := fmt.Sprintf("invalid gRPC request method %q", r.Method) msg := fmt.Sprintf("invalid gRPC request method %q", r.Method)
@@ -170,7 +170,7 @@ type serverHandlerTransport struct {
// TODO make sure this is consistent across handler_server and http2_server // TODO make sure this is consistent across handler_server and http2_server
contentSubtype string contentSubtype string
stats []stats.Handler stats stats.Handler
logger *grpclog.PrefixLogger logger *grpclog.PrefixLogger
bufferPool mem.BufferPool bufferPool mem.BufferPool
@@ -274,14 +274,14 @@ func (ht *serverHandlerTransport) writeStatus(s *ServerStream, st *status.Status
} }
}) })
if err == nil { // transport has not been closed if err == nil && ht.stats != nil { // transport has not been closed
// Note: The trailer fields are compressed with hpack after this call returns. // Note: The trailer fields are compressed with hpack after this call returns.
// No WireLength field is set here. // No WireLength field is set here.
for _, sh := range ht.stats { s.hdrMu.Lock()
sh.HandleRPC(s.Context(), &stats.OutTrailer{ ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{
Trailer: s.trailer.Copy(), Trailer: s.trailer.Copy(),
}) })
} s.hdrMu.Unlock()
} }
ht.Close(errors.New("finished writing status")) ht.Close(errors.New("finished writing status"))
return err return err
@@ -372,19 +372,23 @@ func (ht *serverHandlerTransport) writeHeader(s *ServerStream, md metadata.MD) e
ht.rw.(http.Flusher).Flush() ht.rw.(http.Flusher).Flush()
}) })
if err == nil { if err == nil && ht.stats != nil {
for _, sh := range ht.stats {
// Note: The header fields are compressed with hpack after this call returns. // Note: The header fields are compressed with hpack after this call returns.
// No WireLength field is set here. // No WireLength field is set here.
sh.HandleRPC(s.Context(), &stats.OutHeader{ ht.stats.HandleRPC(s.Context(), &stats.OutHeader{
Header: md.Copy(), Header: md.Copy(),
Compression: s.sendCompress, Compression: s.sendCompress,
}) })
} }
}
return err return err
} }
func (ht *serverHandlerTransport) adjustWindow(*ServerStream, uint32) {
}
func (ht *serverHandlerTransport) updateWindow(*ServerStream, uint32) {
}
func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*ServerStream)) { func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*ServerStream)) {
// With this transport type there will be exactly 1 stream: this HTTP request. // With this transport type there will be exactly 1 stream: this HTTP request.
var cancel context.CancelFunc var cancel context.CancelFunc
@@ -409,11 +413,9 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream
ctx = metadata.NewIncomingContext(ctx, ht.headerMD) ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
req := ht.req req := ht.req
s := &ServerStream{ s := &ServerStream{
Stream: &Stream{ Stream: Stream{
id: 0, // irrelevant id: 0, // irrelevant
ctx: ctx, ctx: ctx,
requestRead: func(int) {},
buf: newRecvBuffer(),
method: req.URL.Path, method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"), recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype, contentSubtype: ht.contentSubtype,
@@ -422,9 +424,11 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream
st: ht, st: ht,
headerWireLength: 0, // won't have access to header wire length until golang/go#18997. headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
} }
s.trReader = &transportReader{ s.Stream.buf.init()
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf}, s.readRequester = s
windowHandler: func(int) {}, s.trReader = transportReader{
reader: recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: &s.buf},
windowHandler: s,
} }
// readerDone is closed when the Body.Read-ing goroutine exits. // readerDone is closed when the Body.Read-ing goroutine exits.

View File

@@ -44,6 +44,7 @@ import (
"google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/grpcutil"
imetadata "google.golang.org/grpc/internal/metadata" imetadata "google.golang.org/grpc/internal/metadata"
"google.golang.org/grpc/internal/proxyattributes" "google.golang.org/grpc/internal/proxyattributes"
istats "google.golang.org/grpc/internal/stats"
istatus "google.golang.org/grpc/internal/status" istatus "google.golang.org/grpc/internal/status"
isyscall "google.golang.org/grpc/internal/syscall" isyscall "google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/internal/transport/networktype" "google.golang.org/grpc/internal/transport/networktype"
@@ -105,7 +106,7 @@ type http2Client struct {
kp keepalive.ClientParameters kp keepalive.ClientParameters
keepaliveEnabled bool keepaliveEnabled bool
statsHandlers []stats.Handler statsHandler stats.Handler
initialWindowSize int32 initialWindowSize int32
@@ -335,14 +336,14 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
writerDone: make(chan struct{}), writerDone: make(chan struct{}),
goAway: make(chan struct{}), goAway: make(chan struct{}),
keepaliveDone: make(chan struct{}), keepaliveDone: make(chan struct{}),
framer: newFramer(conn, writeBufSize, readBufSize, opts.SharedWriteBuffer, maxHeaderListSize), framer: newFramer(conn, writeBufSize, readBufSize, opts.SharedWriteBuffer, maxHeaderListSize, opts.BufferPool),
fc: &trInFlow{limit: uint32(icwz)}, fc: &trInFlow{limit: uint32(icwz)},
scheme: scheme, scheme: scheme,
activeStreams: make(map[uint32]*ClientStream), activeStreams: make(map[uint32]*ClientStream),
isSecure: isSecure, isSecure: isSecure,
perRPCCreds: perRPCCreds, perRPCCreds: perRPCCreds,
kp: kp, kp: kp,
statsHandlers: opts.StatsHandlers, statsHandler: istats.NewCombinedHandler(opts.StatsHandlers...),
initialWindowSize: initialWindowSize, initialWindowSize: initialWindowSize,
nextID: 1, nextID: 1,
maxConcurrentStreams: defaultMaxStreamsClient, maxConcurrentStreams: defaultMaxStreamsClient,
@@ -369,7 +370,7 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
}) })
t.logger = prefixLoggerForClientTransport(t) t.logger = prefixLoggerForClientTransport(t)
// Add peer information to the http2client context. // Add peer information to the http2client context.
t.ctx = peer.NewContext(t.ctx, t.getPeer()) t.ctx = peer.NewContext(t.ctx, t.Peer())
if md, ok := addr.Metadata.(*metadata.MD); ok { if md, ok := addr.Metadata.(*metadata.MD); ok {
t.md = *md t.md = *md
@@ -386,15 +387,14 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
updateFlowControl: t.updateFlowControl, updateFlowControl: t.updateFlowControl,
} }
} }
for _, sh := range t.statsHandlers { if t.statsHandler != nil {
t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{ t.ctx = t.statsHandler.TagConn(t.ctx, &stats.ConnTagInfo{
RemoteAddr: t.remoteAddr, RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr, LocalAddr: t.localAddr,
}) })
connBegin := &stats.ConnBegin{ t.statsHandler.HandleConn(t.ctx, &stats.ConnBegin{
Client: true, Client: true,
} })
sh.HandleConn(t.ctx, connBegin)
} }
if t.keepaliveEnabled { if t.keepaliveEnabled {
t.kpDormancyCond = sync.NewCond(&t.mu) t.kpDormancyCond = sync.NewCond(&t.mu)
@@ -478,45 +478,40 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
return t, nil return t, nil
} }
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *ClientStream { func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr, handler stats.Handler) *ClientStream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id. // TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &ClientStream{ s := &ClientStream{
Stream: &Stream{ Stream: Stream{
method: callHdr.Method, method: callHdr.Method,
sendCompress: callHdr.SendCompress, sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(),
contentSubtype: callHdr.ContentSubtype, contentSubtype: callHdr.ContentSubtype,
}, },
ct: t, ct: t,
done: make(chan struct{}), done: make(chan struct{}),
headerChan: make(chan struct{}), headerChan: make(chan struct{}),
doneFunc: callHdr.DoneFunc, doneFunc: callHdr.DoneFunc,
statsHandler: handler,
} }
s.wq = newWriteQuota(defaultWriteQuota, s.done) s.Stream.buf.init()
s.requestRead = func(n int) { s.Stream.wq.init(defaultWriteQuota, s.done)
t.adjustWindow(s, uint32(n)) s.readRequester = s
}
// The client side stream context should have exactly the same life cycle with the user provided context. // The client side stream context should have exactly the same life cycle with the user provided context.
// That means, s.ctx should be read-only. And s.ctx is done iff ctx is done. // That means, s.ctx should be read-only. And s.ctx is done iff ctx is done.
// So we use the original context here instead of creating a copy. // So we use the original context here instead of creating a copy.
s.ctx = ctx s.ctx = ctx
s.trReader = &transportReader{ s.trReader = transportReader{
reader: &recvBufferReader{ reader: recvBufferReader{
ctx: s.ctx, ctx: s.ctx,
ctxDone: s.ctx.Done(), ctxDone: s.ctx.Done(),
recv: s.buf, recv: &s.buf,
closeStream: func(err error) { clientStream: s,
s.Close(err)
},
},
windowHandler: func(n int) {
t.updateWindow(s, uint32(n))
}, },
windowHandler: s,
} }
return s return s
} }
func (t *http2Client) getPeer() *peer.Peer { func (t *http2Client) Peer() *peer.Peer {
return &peer.Peer{ return &peer.Peer{
Addr: t.remoteAddr, Addr: t.remoteAddr,
AuthInfo: t.authInfo, // Can be nil AuthInfo: t.authInfo, // Can be nil
@@ -556,6 +551,22 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
// Make the slice of certain predictable size to reduce allocations made by append. // Make the slice of certain predictable size to reduce allocations made by append.
hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te
hfLen += len(authData) + len(callAuthData) hfLen += len(authData) + len(callAuthData)
registeredCompressors := t.registeredCompressors
if callHdr.AcceptedCompressors != nil {
registeredCompressors = *callHdr.AcceptedCompressors
}
if callHdr.PreviousAttempts > 0 {
hfLen++
}
if callHdr.SendCompress != "" {
hfLen++
}
if registeredCompressors != "" {
hfLen++
}
if _, ok := ctx.Deadline(); ok {
hfLen++
}
headerFields := make([]hpack.HeaderField, 0, hfLen) headerFields := make([]hpack.HeaderField, 0, hfLen)
headerFields = append(headerFields, hpack.HeaderField{Name: ":method", Value: "POST"}) headerFields = append(headerFields, hpack.HeaderField{Name: ":method", Value: "POST"})
headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme}) headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme})
@@ -568,7 +579,6 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-previous-rpc-attempts", Value: strconv.Itoa(callHdr.PreviousAttempts)}) headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-previous-rpc-attempts", Value: strconv.Itoa(callHdr.PreviousAttempts)})
} }
registeredCompressors := t.registeredCompressors
if callHdr.SendCompress != "" { if callHdr.SendCompress != "" {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
// Include the outgoing compressor name when compressor is not registered // Include the outgoing compressor name when compressor is not registered
@@ -735,8 +745,8 @@ func (e NewStreamError) Error() string {
// NewStream creates a stream and registers it into the transport as "active" // NewStream creates a stream and registers it into the transport as "active"
// streams. All non-nil errors returned will be *NewStreamError. // streams. All non-nil errors returned will be *NewStreamError.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientStream, error) { func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr, handler stats.Handler) (*ClientStream, error) {
ctx = peer.NewContext(ctx, t.getPeer()) ctx = peer.NewContext(ctx, t.Peer())
// ServerName field of the resolver returned address takes precedence over // ServerName field of the resolver returned address takes precedence over
// Host field of CallHdr to determine the :authority header. This is because, // Host field of CallHdr to determine the :authority header. This is because,
@@ -772,7 +782,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS
if err != nil { if err != nil {
return nil, &NewStreamError{Err: err, AllowTransparentRetry: false} return nil, &NewStreamError{Err: err, AllowTransparentRetry: false}
} }
s := t.newStream(ctx, callHdr) s := t.newStream(ctx, callHdr, handler)
cleanup := func(err error) { cleanup := func(err error) {
if s.swapState(streamDone) == streamDone { if s.swapState(streamDone) == streamDone {
// If it was already done, return. // If it was already done, return.
@@ -811,7 +821,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS
return nil return nil
}, },
onOrphaned: cleanup, onOrphaned: cleanup,
wq: s.wq, wq: &s.wq,
} }
firstTry := true firstTry := true
var ch chan struct{} var ch chan struct{}
@@ -842,7 +852,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS
transportDrainRequired = t.nextID > MaxStreamID transportDrainRequired = t.nextID > MaxStreamID
s.id = hdr.streamID s.id = hdr.streamID
s.fc = &inFlow{limit: uint32(t.initialWindowSize)} s.fc = inFlow{limit: uint32(t.initialWindowSize)}
t.activeStreams[s.id] = s t.activeStreams[s.id] = s
t.mu.Unlock() t.mu.Unlock()
@@ -893,27 +903,23 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS
return nil, &NewStreamError{Err: ErrConnClosing, AllowTransparentRetry: true} return nil, &NewStreamError{Err: ErrConnClosing, AllowTransparentRetry: true}
} }
} }
if len(t.statsHandlers) != 0 { if s.statsHandler != nil {
header, ok := metadata.FromOutgoingContext(ctx) header, ok := metadata.FromOutgoingContext(ctx)
if ok { if ok {
header.Set("user-agent", t.userAgent) header.Set("user-agent", t.userAgent)
} else { } else {
header = metadata.Pairs("user-agent", t.userAgent) header = metadata.Pairs("user-agent", t.userAgent)
} }
for _, sh := range t.statsHandlers {
// Note: The header fields are compressed with hpack after this call returns. // Note: The header fields are compressed with hpack after this call returns.
// No WireLength field is set here. // No WireLength field is set here.
// Note: Creating a new stats object to prevent pollution. s.statsHandler.HandleRPC(s.ctx, &stats.OutHeader{
outHeader := &stats.OutHeader{
Client: true, Client: true,
FullMethod: callHdr.Method, FullMethod: callHdr.Method,
RemoteAddr: t.remoteAddr, RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr, LocalAddr: t.localAddr,
Compression: callHdr.SendCompress, Compression: callHdr.SendCompress,
Header: header, Header: header,
} })
sh.HandleRPC(s.ctx, outHeader)
}
} }
if transportDrainRequired { if transportDrainRequired {
if t.logger.V(logLevel) { if t.logger.V(logLevel) {
@@ -990,6 +996,9 @@ func (t *http2Client) closeStream(s *ClientStream, err error, rst bool, rstCode
// accessed anymore. // accessed anymore.
func (t *http2Client) Close(err error) { func (t *http2Client) Close(err error) {
t.conn.SetWriteDeadline(time.Now().Add(time.Second * 10)) t.conn.SetWriteDeadline(time.Now().Add(time.Second * 10))
// For background on the deadline value chosen here, see
// https://github.com/grpc/grpc-go/issues/8425#issuecomment-3057938248 .
t.conn.SetReadDeadline(time.Now().Add(time.Second))
t.mu.Lock() t.mu.Lock()
// Make sure we only close once. // Make sure we only close once.
if t.state == closing { if t.state == closing {
@@ -1051,11 +1060,10 @@ func (t *http2Client) Close(err error) {
for _, s := range streams { for _, s := range streams {
t.closeStream(s, err, false, http2.ErrCodeNo, st, nil, false) t.closeStream(s, err, false, http2.ErrCodeNo, st, nil, false)
} }
for _, sh := range t.statsHandlers { if t.statsHandler != nil {
connEnd := &stats.ConnEnd{ t.statsHandler.HandleConn(t.ctx, &stats.ConnEnd{
Client: true, Client: true,
} })
sh.HandleConn(t.ctx, connEnd)
} }
} }
@@ -1166,7 +1174,7 @@ func (t *http2Client) updateFlowControl(n uint32) {
}) })
} }
func (t *http2Client) handleData(f *http2.DataFrame) { func (t *http2Client) handleData(f *parsedDataFrame) {
size := f.Header().Length size := f.Header().Length
var sendBDPPing bool var sendBDPPing bool
if t.bdpEst != nil { if t.bdpEst != nil {
@@ -1210,22 +1218,15 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false) t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false)
return return
} }
dataLen := f.data.Len()
if f.Header().Flags.Has(http2.FlagDataPadded) { if f.Header().Flags.Has(http2.FlagDataPadded) {
if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 { if w := s.fc.onRead(size - uint32(dataLen)); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{s.id, w}) t.controlBuf.put(&outgoingWindowUpdate{s.id, w})
} }
} }
// TODO(bradfitz, zhaoq): A copy is required here because there is no if dataLen > 0 {
// guarantee f.Data() is consumed before the arrival of next frame. f.data.Ref()
// Can this copy be eliminated? s.write(recvMsg{buffer: f.data})
if len(f.Data()) > 0 {
pool := t.bufferPool
if pool == nil {
// Note that this is only supposed to be nil in tests. Otherwise, stream is
// always initialized with a BufferPool.
pool = mem.DefaultBufferPool()
}
s.write(recvMsg{buffer: mem.Copy(f.Data(), pool)})
} }
} }
// The server has closed the stream without sending trailers. Record that // The server has closed the stream without sending trailers. Record that
@@ -1465,17 +1466,14 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
contentTypeErr = "malformed header: missing HTTP content-type" contentTypeErr = "malformed header: missing HTTP content-type"
grpcMessage string grpcMessage string
recvCompress string recvCompress string
httpStatusCode *int
httpStatusErr string httpStatusErr string
rawStatusCode = codes.Unknown // the code from the grpc-status header, if present
grpcStatusCode = codes.Unknown
// headerError is set if an error is encountered while parsing the headers // headerError is set if an error is encountered while parsing the headers
headerError string headerError string
httpStatus string
) )
if initialHeader {
httpStatusErr = "malformed header: missing HTTP status"
}
for _, hf := range frame.Fields { for _, hf := range frame.Fields {
switch hf.Name { switch hf.Name {
case "content-type": case "content-type":
@@ -1491,35 +1489,15 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
case "grpc-status": case "grpc-status":
code, err := strconv.ParseInt(hf.Value, 10, 32) code, err := strconv.ParseInt(hf.Value, 10, 32)
if err != nil { if err != nil {
se := status.New(codes.Internal, fmt.Sprintf("transport: malformed grpc-status: %v", err)) se := status.New(codes.Unknown, fmt.Sprintf("transport: malformed grpc-status: %v", err))
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
return return
} }
rawStatusCode = codes.Code(uint32(code)) grpcStatusCode = codes.Code(uint32(code))
case "grpc-message": case "grpc-message":
grpcMessage = decodeGrpcMessage(hf.Value) grpcMessage = decodeGrpcMessage(hf.Value)
case ":status": case ":status":
if hf.Value == "200" { httpStatus = hf.Value
httpStatusErr = ""
statusCode := 200
httpStatusCode = &statusCode
break
}
c, err := strconv.ParseInt(hf.Value, 10, 32)
if err != nil {
se := status.New(codes.Internal, fmt.Sprintf("transport: malformed http-status: %v", err))
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
return
}
statusCode := int(c)
httpStatusCode = &statusCode
httpStatusErr = fmt.Sprintf(
"unexpected HTTP status code received from server: %d (%s)",
statusCode,
http.StatusText(statusCode),
)
default: default:
if isReservedHeader(hf.Name) && !isWhitelistedHeader(hf.Name) { if isReservedHeader(hf.Name) && !isWhitelistedHeader(hf.Name) {
break break
@@ -1534,25 +1512,52 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
} }
} }
if !isGRPC || httpStatusErr != "" { // If a non-gRPC response is received, then evaluate the HTTP status to
var code = codes.Internal // when header does not include HTTP status, return INTERNAL // process the response and close the stream.
// In case http status doesn't provide any error information (status : 200),
if httpStatusCode != nil { // then evalute response code to be Unknown.
if !isGRPC {
var grpcErrorCode = codes.Internal
if httpStatus == "" {
httpStatusErr = "malformed header: missing HTTP status"
} else {
// Parse the status codes (e.g. "200", 404").
statusCode, err := strconv.Atoi(httpStatus)
if err != nil {
se := status.New(grpcErrorCode, fmt.Sprintf("transport: malformed http-status: %v", err))
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
return
}
if statusCode >= 100 && statusCode < 200 {
if endStream {
se := status.New(codes.Internal, fmt.Sprintf(
"protocol error: informational header with status code %d must not have END_STREAM set", statusCode))
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
}
// In case of informational headers, return.
return
}
httpStatusErr = fmt.Sprintf(
"unexpected HTTP status code received from server: %d (%s)",
statusCode,
http.StatusText(statusCode),
)
var ok bool var ok bool
code, ok = HTTPStatusConvTab[*httpStatusCode] grpcErrorCode, ok = HTTPStatusConvTab[statusCode]
if !ok { if !ok {
code = codes.Unknown grpcErrorCode = codes.Unknown
} }
} }
var errs []string var errs []string
if httpStatusErr != "" { if httpStatusErr != "" {
errs = append(errs, httpStatusErr) errs = append(errs, httpStatusErr)
} }
if contentTypeErr != "" { if contentTypeErr != "" {
errs = append(errs, contentTypeErr) errs = append(errs, contentTypeErr)
} }
// Verify the HTTP response is a 200.
se := status.New(code, strings.Join(errs, "; ")) se := status.New(grpcErrorCode, strings.Join(errs, "; "))
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
return return
} }
@@ -1583,22 +1588,20 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
} }
} }
for _, sh := range t.statsHandlers { if s.statsHandler != nil {
if !endStream { if !endStream {
inHeader := &stats.InHeader{ s.statsHandler.HandleRPC(s.ctx, &stats.InHeader{
Client: true, Client: true,
WireLength: int(frame.Header().Length), WireLength: int(frame.Header().Length),
Header: metadata.MD(mdata).Copy(), Header: metadata.MD(mdata).Copy(),
Compression: s.recvCompress, Compression: s.recvCompress,
} })
sh.HandleRPC(s.ctx, inHeader)
} else { } else {
inTrailer := &stats.InTrailer{ s.statsHandler.HandleRPC(s.ctx, &stats.InTrailer{
Client: true, Client: true,
WireLength: int(frame.Header().Length), WireLength: int(frame.Header().Length),
Trailer: metadata.MD(mdata).Copy(), Trailer: metadata.MD(mdata).Copy(),
} })
sh.HandleRPC(s.ctx, inTrailer)
} }
} }
@@ -1606,7 +1609,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
return return
} }
status := istatus.NewWithProto(rawStatusCode, grpcMessage, mdata[grpcStatusDetailsBinHeader]) status := istatus.NewWithProto(grpcStatusCode, grpcMessage, mdata[grpcStatusDetailsBinHeader])
// If client received END_STREAM from server while stream was still active, // If client received END_STREAM from server while stream was still active,
// send RST_STREAM. // send RST_STREAM.
@@ -1653,7 +1656,7 @@ func (t *http2Client) reader(errCh chan<- error) {
// loop to keep reading incoming messages on this transport. // loop to keep reading incoming messages on this transport.
for { for {
t.controlBuf.throttle() t.controlBuf.throttle()
frame, err := t.framer.fr.ReadFrame() frame, err := t.framer.readFrame()
if t.keepaliveEnabled { if t.keepaliveEnabled {
atomic.StoreInt64(&t.lastRead, time.Now().UnixNano()) atomic.StoreInt64(&t.lastRead, time.Now().UnixNano())
} }
@@ -1668,7 +1671,7 @@ func (t *http2Client) reader(errCh chan<- error) {
if s != nil { if s != nil {
// use error detail to provide better err message // use error detail to provide better err message
code := http2ErrConvTab[se.Code] code := http2ErrConvTab[se.Code]
errorDetail := t.framer.fr.ErrorDetail() errorDetail := t.framer.errorDetail()
var msg string var msg string
if errorDetail != nil { if errorDetail != nil {
msg = errorDetail.Error() msg = errorDetail.Error()
@@ -1686,8 +1689,9 @@ func (t *http2Client) reader(errCh chan<- error) {
switch frame := frame.(type) { switch frame := frame.(type) {
case *http2.MetaHeadersFrame: case *http2.MetaHeadersFrame:
t.operateHeaders(frame) t.operateHeaders(frame)
case *http2.DataFrame: case *parsedDataFrame:
t.handleData(frame) t.handleData(frame)
frame.data.Free()
case *http2.RSTStreamFrame: case *http2.RSTStreamFrame:
t.handleRSTStream(frame) t.handleRSTStream(frame)
case *http2.SettingsFrame: case *http2.SettingsFrame:
@@ -1807,8 +1811,6 @@ func (t *http2Client) socketMetrics() *channelz.EphemeralSocketMetrics {
} }
} }
func (t *http2Client) RemoteAddr() net.Addr { return t.remoteAddr }
func (t *http2Client) incrMsgSent() { func (t *http2Client) incrMsgSent() {
if channelz.IsOn() { if channelz.IsOn() {
t.channelz.SocketMetrics.MessagesSent.Add(1) t.channelz.SocketMetrics.MessagesSent.Add(1)

View File

@@ -35,6 +35,8 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
"google.golang.org/protobuf/proto"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/grpcutil"
@@ -42,7 +44,6 @@ import (
istatus "google.golang.org/grpc/internal/status" istatus "google.golang.org/grpc/internal/status"
"google.golang.org/grpc/internal/syscall" "google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/mem" "google.golang.org/grpc/mem"
"google.golang.org/protobuf/proto"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
@@ -86,7 +87,7 @@ type http2Server struct {
// updates, reset streams, and various settings) to the controller. // updates, reset streams, and various settings) to the controller.
controlBuf *controlBuffer controlBuf *controlBuffer
fc *trInFlow fc *trInFlow
stats []stats.Handler stats stats.Handler
// Keepalive and max-age parameters for the server. // Keepalive and max-age parameters for the server.
kp keepalive.ServerParameters kp keepalive.ServerParameters
// Keepalive enforcement policy. // Keepalive enforcement policy.
@@ -168,7 +169,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
if config.MaxHeaderListSize != nil { if config.MaxHeaderListSize != nil {
maxHeaderListSize = *config.MaxHeaderListSize maxHeaderListSize = *config.MaxHeaderListSize
} }
framer := newFramer(conn, writeBufSize, readBufSize, config.SharedWriteBuffer, maxHeaderListSize) framer := newFramer(conn, writeBufSize, readBufSize, config.SharedWriteBuffer, maxHeaderListSize, config.BufferPool)
// Send initial settings as connection preface to client. // Send initial settings as connection preface to client.
isettings := []http2.Setting{{ isettings := []http2.Setting{{
ID: http2.SettingMaxFrameSize, ID: http2.SettingMaxFrameSize,
@@ -260,7 +261,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
fc: &trInFlow{limit: uint32(icwz)}, fc: &trInFlow{limit: uint32(icwz)},
state: reachable, state: reachable,
activeStreams: make(map[uint32]*ServerStream), activeStreams: make(map[uint32]*ServerStream),
stats: config.StatsHandlers, stats: config.StatsHandler,
kp: kp, kp: kp,
idle: time.Now(), idle: time.Now(),
kep: kep, kep: kep,
@@ -390,16 +391,15 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
} }
t.maxStreamID = streamID t.maxStreamID = streamID
buf := newRecvBuffer()
s := &ServerStream{ s := &ServerStream{
Stream: &Stream{ Stream: Stream{
id: streamID, id: streamID,
buf: buf, fc: inFlow{limit: uint32(t.initialWindowSize)},
fc: &inFlow{limit: uint32(t.initialWindowSize)},
}, },
st: t, st: t,
headerWireLength: int(frame.Header().Length), headerWireLength: int(frame.Header().Length),
} }
s.Stream.buf.init()
var ( var (
// if false, content-type was missing or invalid // if false, content-type was missing or invalid
isGRPC = false isGRPC = false
@@ -479,13 +479,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
if t.logger.V(logLevel) { if t.logger.V(logLevel) {
t.logger.Infof("Aborting the stream early: %v", errMsg) t.logger.Infof("Aborting the stream early: %v", errMsg)
} }
t.controlBuf.put(&earlyAbortStream{ t.writeEarlyAbort(streamID, s.contentSubtype, status.New(codes.Internal, errMsg), http.StatusBadRequest, !frame.StreamEnded())
httpStatus: http.StatusBadRequest,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: status.New(codes.Internal, errMsg),
rst: !frame.StreamEnded(),
})
return nil return nil
} }
@@ -499,23 +493,11 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
return nil return nil
} }
if !isGRPC { if !isGRPC {
t.controlBuf.put(&earlyAbortStream{ t.writeEarlyAbort(streamID, s.contentSubtype, status.Newf(codes.InvalidArgument, "invalid gRPC request content-type %q", contentType), http.StatusUnsupportedMediaType, !frame.StreamEnded())
httpStatus: http.StatusUnsupportedMediaType,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: status.Newf(codes.InvalidArgument, "invalid gRPC request content-type %q", contentType),
rst: !frame.StreamEnded(),
})
return nil return nil
} }
if headerError != nil { if headerError != nil {
t.controlBuf.put(&earlyAbortStream{ t.writeEarlyAbort(streamID, s.contentSubtype, headerError, http.StatusBadRequest, !frame.StreamEnded())
httpStatus: http.StatusBadRequest,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: headerError,
rst: !frame.StreamEnded(),
})
return nil return nil
} }
@@ -569,13 +551,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
if t.logger.V(logLevel) { if t.logger.V(logLevel) {
t.logger.Infof("Aborting the stream early: %v", errMsg) t.logger.Infof("Aborting the stream early: %v", errMsg)
} }
t.controlBuf.put(&earlyAbortStream{ t.writeEarlyAbort(streamID, s.contentSubtype, status.New(codes.Internal, errMsg), http.StatusMethodNotAllowed, !frame.StreamEnded())
httpStatus: http.StatusMethodNotAllowed,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: status.New(codes.Internal, errMsg),
rst: !frame.StreamEnded(),
})
s.cancel() s.cancel()
return nil return nil
} }
@@ -590,27 +566,16 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
if !ok { if !ok {
stat = status.New(codes.PermissionDenied, err.Error()) stat = status.New(codes.PermissionDenied, err.Error())
} }
t.controlBuf.put(&earlyAbortStream{ t.writeEarlyAbort(s.id, s.contentSubtype, stat, http.StatusOK, !frame.StreamEnded())
httpStatus: http.StatusOK,
streamID: s.id,
contentSubtype: s.contentSubtype,
status: stat,
rst: !frame.StreamEnded(),
})
return nil return nil
} }
} }
if s.ctx.Err() != nil { if s.ctx.Err() != nil {
t.mu.Unlock() t.mu.Unlock()
st := status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
// Early abort in case the timeout was zero or so low it already fired. // Early abort in case the timeout was zero or so low it already fired.
t.controlBuf.put(&earlyAbortStream{ t.writeEarlyAbort(s.id, s.contentSubtype, st, http.StatusOK, !frame.StreamEnded())
httpStatus: http.StatusOK,
streamID: s.id,
contentSubtype: s.contentSubtype,
status: status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()),
rst: !frame.StreamEnded(),
})
return nil return nil
} }
@@ -640,25 +605,21 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
t.channelz.SocketMetrics.StreamsStarted.Add(1) t.channelz.SocketMetrics.StreamsStarted.Add(1)
t.channelz.SocketMetrics.LastRemoteStreamCreatedTimestamp.Store(time.Now().UnixNano()) t.channelz.SocketMetrics.LastRemoteStreamCreatedTimestamp.Store(time.Now().UnixNano())
} }
s.requestRead = func(n int) { s.readRequester = s
t.adjustWindow(s, uint32(n))
}
s.ctxDone = s.ctx.Done() s.ctxDone = s.ctx.Done()
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) s.Stream.wq.init(defaultWriteQuota, s.ctxDone)
s.trReader = &transportReader{ s.trReader = transportReader{
reader: &recvBufferReader{ reader: recvBufferReader{
ctx: s.ctx, ctx: s.ctx,
ctxDone: s.ctxDone, ctxDone: s.ctxDone,
recv: s.buf, recv: &s.buf,
},
windowHandler: func(n int) {
t.updateWindow(s, uint32(n))
}, },
windowHandler: s,
} }
// Register the stream with loopy. // Register the stream with loopy.
t.controlBuf.put(&registerStream{ t.controlBuf.put(&registerStream{
streamID: s.id, streamID: s.id,
wq: s.wq, wq: &s.wq,
}) })
handle(s) handle(s)
return nil return nil
@@ -674,7 +635,7 @@ func (t *http2Server) HandleStreams(ctx context.Context, handle func(*ServerStre
}() }()
for { for {
t.controlBuf.throttle() t.controlBuf.throttle()
frame, err := t.framer.fr.ReadFrame() frame, err := t.framer.readFrame()
atomic.StoreInt64(&t.lastRead, time.Now().UnixNano()) atomic.StoreInt64(&t.lastRead, time.Now().UnixNano())
if err != nil { if err != nil {
if se, ok := err.(http2.StreamError); ok { if se, ok := err.(http2.StreamError); ok {
@@ -711,8 +672,9 @@ func (t *http2Server) HandleStreams(ctx context.Context, handle func(*ServerStre
}) })
continue continue
} }
case *http2.DataFrame: case *parsedDataFrame:
t.handleData(frame) t.handleData(frame)
frame.data.Free()
case *http2.RSTStreamFrame: case *http2.RSTStreamFrame:
t.handleRSTStream(frame) t.handleRSTStream(frame)
case *http2.SettingsFrame: case *http2.SettingsFrame:
@@ -792,7 +754,7 @@ func (t *http2Server) updateFlowControl(n uint32) {
} }
func (t *http2Server) handleData(f *http2.DataFrame) { func (t *http2Server) handleData(f *parsedDataFrame) {
size := f.Header().Length size := f.Header().Length
var sendBDPPing bool var sendBDPPing bool
if t.bdpEst != nil { if t.bdpEst != nil {
@@ -837,22 +799,15 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
t.closeStream(s, true, http2.ErrCodeFlowControl, false) t.closeStream(s, true, http2.ErrCodeFlowControl, false)
return return
} }
dataLen := f.data.Len()
if f.Header().Flags.Has(http2.FlagDataPadded) { if f.Header().Flags.Has(http2.FlagDataPadded) {
if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 { if w := s.fc.onRead(size - uint32(dataLen)); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{s.id, w}) t.controlBuf.put(&outgoingWindowUpdate{s.id, w})
} }
} }
// TODO(bradfitz, zhaoq): A copy is required here because there is no if dataLen > 0 {
// guarantee f.Data() is consumed before the arrival of next frame. f.data.Ref()
// Can this copy be eliminated? s.write(recvMsg{buffer: f.data})
if len(f.Data()) > 0 {
pool := t.bufferPool
if pool == nil {
// Note that this is only supposed to be nil in tests. Otherwise, stream is
// always initialized with a BufferPool.
pool = mem.DefaultBufferPool()
}
s.write(recvMsg{buffer: mem.Copy(f.Data(), pool)})
} }
} }
if f.StreamEnded() { if f.StreamEnded() {
@@ -979,13 +934,12 @@ func appendHeaderFieldsFromMD(headerFields []hpack.HeaderField, md metadata.MD)
return headerFields return headerFields
} }
func (t *http2Server) checkForHeaderListSize(it any) bool { func (t *http2Server) checkForHeaderListSize(hf []hpack.HeaderField) bool {
if t.maxSendHeaderListSize == nil { if t.maxSendHeaderListSize == nil {
return true return true
} }
hdrFrame := it.(*headerFrame)
var sz int64 var sz int64
for _, f := range hdrFrame.hf { for _, f := range hf {
if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) { if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) {
if t.logger.V(logLevel) { if t.logger.V(logLevel) {
t.logger.Infof("Header list size to send violates the maximum size (%d bytes) set by client", *t.maxSendHeaderListSize) t.logger.Infof("Header list size to send violates the maximum size (%d bytes) set by client", *t.maxSendHeaderListSize)
@@ -996,6 +950,42 @@ func (t *http2Server) checkForHeaderListSize(it any) bool {
return true return true
} }
// writeEarlyAbort sends an early abort response with the given HTTP status and
// gRPC status. If the header list size exceeds the peer's limit, it sends a
// RST_STREAM instead.
func (t *http2Server) writeEarlyAbort(streamID uint32, contentSubtype string, stat *status.Status, httpStatus uint32, rst bool) {
hf := []hpack.HeaderField{
{Name: ":status", Value: strconv.Itoa(int(httpStatus))},
{Name: "content-type", Value: grpcutil.ContentType(contentSubtype)},
{Name: "grpc-status", Value: strconv.Itoa(int(stat.Code()))},
{Name: "grpc-message", Value: encodeGrpcMessage(stat.Message())},
}
if p := istatus.RawStatusProto(stat); len(p.GetDetails()) > 0 {
stBytes, err := proto.Marshal(p)
if err != nil {
t.logger.Errorf("Failed to marshal rpc status: %s, error: %v", pretty.ToJSON(p), err)
}
if err == nil {
hf = append(hf, hpack.HeaderField{Name: grpcStatusDetailsBinHeader, Value: encodeBinHeader(stBytes)})
}
}
success, _ := t.controlBuf.executeAndPut(func() bool {
return t.checkForHeaderListSize(hf)
}, &earlyAbortStream{
streamID: streamID,
rst: rst,
hf: hf,
})
if !success {
t.controlBuf.put(&cleanupStream{
streamID: streamID,
rst: true,
rstCode: http2.ErrCodeInternal,
onWrite: func() {},
})
}
}
func (t *http2Server) streamContextErr(s *ServerStream) error { func (t *http2Server) streamContextErr(s *ServerStream) error {
select { select {
case <-t.done: case <-t.done:
@@ -1051,7 +1041,7 @@ func (t *http2Server) writeHeaderLocked(s *ServerStream) error {
endStream: false, endStream: false,
onWrite: t.setResetPingStrikes, onWrite: t.setResetPingStrikes,
} }
success, err := t.controlBuf.executeAndPut(func() bool { return t.checkForHeaderListSize(hf) }, hf) success, err := t.controlBuf.executeAndPut(func() bool { return t.checkForHeaderListSize(hf.hf) }, hf)
if !success { if !success {
if err != nil { if err != nil {
return err return err
@@ -1059,14 +1049,13 @@ func (t *http2Server) writeHeaderLocked(s *ServerStream) error {
t.closeStream(s, true, http2.ErrCodeInternal, false) t.closeStream(s, true, http2.ErrCodeInternal, false)
return ErrHeaderListSizeLimitViolation return ErrHeaderListSizeLimitViolation
} }
for _, sh := range t.stats { if t.stats != nil {
// Note: Headers are compressed with hpack after this call returns. // Note: Headers are compressed with hpack after this call returns.
// No WireLength field is set here. // No WireLength field is set here.
outHeader := &stats.OutHeader{ t.stats.HandleRPC(s.Context(), &stats.OutHeader{
Header: s.header.Copy(), Header: s.header.Copy(),
Compression: s.sendCompress, Compression: s.sendCompress,
} })
sh.HandleRPC(s.Context(), outHeader)
} }
return nil return nil
} }
@@ -1122,7 +1111,7 @@ func (t *http2Server) writeStatus(s *ServerStream, st *status.Status) error {
} }
success, err := t.controlBuf.executeAndPut(func() bool { success, err := t.controlBuf.executeAndPut(func() bool {
return t.checkForHeaderListSize(trailingHeader) return t.checkForHeaderListSize(trailingHeader.hf)
}, nil) }, nil)
if !success { if !success {
if err != nil { if err != nil {
@@ -1134,10 +1123,10 @@ func (t *http2Server) writeStatus(s *ServerStream, st *status.Status) error {
// Send a RST_STREAM after the trailers if the client has not already half-closed. // Send a RST_STREAM after the trailers if the client has not already half-closed.
rst := s.getState() == streamActive rst := s.getState() == streamActive
t.finishStream(s, rst, http2.ErrCodeNo, trailingHeader, true) t.finishStream(s, rst, http2.ErrCodeNo, trailingHeader, true)
for _, sh := range t.stats { if t.stats != nil {
// Note: The trailer fields are compressed with hpack after this call returns. // Note: The trailer fields are compressed with hpack after this call returns.
// No WireLength field is set here. // No WireLength field is set here.
sh.HandleRPC(s.Context(), &stats.OutTrailer{ t.stats.HandleRPC(s.Context(), &stats.OutTrailer{
Trailer: s.trailer.Copy(), Trailer: s.trailer.Copy(),
}) })
} }
@@ -1305,7 +1294,8 @@ func (t *http2Server) Close(err error) {
// deleteStream deletes the stream s from transport's active streams. // deleteStream deletes the stream s from transport's active streams.
func (t *http2Server) deleteStream(s *ServerStream, eosReceived bool) { func (t *http2Server) deleteStream(s *ServerStream, eosReceived bool) {
t.mu.Lock() t.mu.Lock()
if _, ok := t.activeStreams[s.id]; ok { _, isActive := t.activeStreams[s.id]
if isActive {
delete(t.activeStreams, s.id) delete(t.activeStreams, s.id)
if len(t.activeStreams) == 0 { if len(t.activeStreams) == 0 {
t.idle = time.Now() t.idle = time.Now()
@@ -1313,7 +1303,7 @@ func (t *http2Server) deleteStream(s *ServerStream, eosReceived bool) {
} }
t.mu.Unlock() t.mu.Unlock()
if channelz.IsOn() { if isActive && channelz.IsOn() {
if eosReceived { if eosReceived {
t.channelz.SocketMetrics.StreamsSucceeded.Add(1) t.channelz.SocketMetrics.StreamsSucceeded.Add(1)
} else { } else {
@@ -1353,10 +1343,10 @@ func (t *http2Server) closeStream(s *ServerStream, rst bool, rstCode http2.ErrCo
// called to interrupt the potential blocking on other goroutines. // called to interrupt the potential blocking on other goroutines.
s.cancel() s.cancel()
oldState := s.swapState(streamDone) // We can't return early even if the stream's state is "done" as the state
if oldState == streamDone { // might have been set by the `finishStream` method. Deleting the stream via
return // `finishStream` can get blocked on flow control.
} s.swapState(streamDone)
t.deleteStream(s, eosReceived) t.deleteStream(s, eosReceived)
t.controlBuf.put(&cleanupStream{ t.controlBuf.put(&cleanupStream{

View File

@@ -25,7 +25,6 @@ import (
"fmt" "fmt"
"io" "io"
"math" "math"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
@@ -37,6 +36,7 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/mem"
) )
const ( const (
@@ -300,11 +300,11 @@ type bufWriter struct {
buf []byte buf []byte
offset int offset int
batchSize int batchSize int
conn net.Conn conn io.Writer
err error err error
} }
func newBufWriter(conn net.Conn, batchSize int, pool *sync.Pool) *bufWriter { func newBufWriter(conn io.Writer, batchSize int, pool *sync.Pool) *bufWriter {
w := &bufWriter{ w := &bufWriter{
batchSize: batchSize, batchSize: batchSize,
conn: conn, conn: conn,
@@ -388,15 +388,29 @@ func toIOError(err error) error {
return ioError{error: err} return ioError{error: err}
} }
type parsedDataFrame struct {
http2.FrameHeader
data mem.Buffer
}
func (df *parsedDataFrame) StreamEnded() bool {
return df.FrameHeader.Flags.Has(http2.FlagDataEndStream)
}
type framer struct { type framer struct {
writer *bufWriter writer *bufWriter
fr *http2.Framer fr *http2.Framer
headerBuf []byte // cached slice for framer headers to reduce heap allocs.
reader io.Reader
dataFrame parsedDataFrame // Cached data frame to avoid heap allocations.
pool mem.BufferPool
errDetail error
} }
var writeBufferPoolMap = make(map[int]*sync.Pool) var writeBufferPoolMap = make(map[int]*sync.Pool)
var writeBufferMutex sync.Mutex var writeBufferMutex sync.Mutex
func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32) *framer { func newFramer(conn io.ReadWriter, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32, memPool mem.BufferPool) *framer {
if writeBufferSize < 0 { if writeBufferSize < 0 {
writeBufferSize = 0 writeBufferSize = 0
} }
@@ -412,6 +426,8 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBu
f := &framer{ f := &framer{
writer: w, writer: w,
fr: http2.NewFramer(w, r), fr: http2.NewFramer(w, r),
reader: r,
pool: memPool,
} }
f.fr.SetMaxReadFrameSize(http2MaxFrameLen) f.fr.SetMaxReadFrameSize(http2MaxFrameLen)
// Opt-in to Frame reuse API on framer to reduce garbage. // Opt-in to Frame reuse API on framer to reduce garbage.
@@ -422,6 +438,146 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBu
return f return f
} }
// writeData writes a DATA frame.
//
// It is the caller's responsibility not to violate the maximum frame size.
func (f *framer) writeData(streamID uint32, endStream bool, data [][]byte) error {
var flags http2.Flags
if endStream {
flags = http2.FlagDataEndStream
}
length := uint32(0)
for _, d := range data {
length += uint32(len(d))
}
// TODO: Replace the header write with the framer API being added in
// https://github.com/golang/go/issues/66655.
f.headerBuf = append(f.headerBuf[:0],
byte(length>>16),
byte(length>>8),
byte(length),
byte(http2.FrameData),
byte(flags),
byte(streamID>>24),
byte(streamID>>16),
byte(streamID>>8),
byte(streamID))
if _, err := f.writer.Write(f.headerBuf); err != nil {
return err
}
for _, d := range data {
if _, err := f.writer.Write(d); err != nil {
return err
}
}
return nil
}
// readFrame reads a single frame. The returned Frame is only valid
// until the next call to readFrame.
func (f *framer) readFrame() (any, error) {
f.errDetail = nil
fh, err := f.fr.ReadFrameHeader()
if err != nil {
f.errDetail = f.fr.ErrorDetail()
return nil, err
}
// Read the data frame directly from the underlying io.Reader to avoid
// copies.
if fh.Type == http2.FrameData {
err = f.readDataFrame(fh)
return &f.dataFrame, err
}
fr, err := f.fr.ReadFrameForHeader(fh)
if err != nil {
f.errDetail = f.fr.ErrorDetail()
return nil, err
}
return fr, err
}
// errorDetail returns a more detailed error of the last error
// returned by framer.readFrame. For instance, if readFrame
// returns a StreamError with code PROTOCOL_ERROR, errorDetail
// will say exactly what was invalid. errorDetail is not guaranteed
// to return a non-nil value.
// errorDetail is reset after the next call to readFrame.
func (f *framer) errorDetail() error {
return f.errDetail
}
func (f *framer) readDataFrame(fh http2.FrameHeader) (err error) {
if fh.StreamID == 0 {
// DATA frames MUST be associated with a stream. If a
// DATA frame is received whose stream identifier
// field is 0x0, the recipient MUST respond with a
// connection error (Section 5.4.1) of type
// PROTOCOL_ERROR.
f.errDetail = errors.New("DATA frame with stream ID 0")
return http2.ConnectionError(http2.ErrCodeProtocol)
}
// Converting a *[]byte to a mem.SliceBuffer incurs a heap allocation. This
// conversion is performed by mem.NewBuffer. To avoid the extra allocation
// a []byte is allocated directly if required and cast to a mem.SliceBuffer.
var buf []byte
// poolHandle is the pointer returned by the buffer pool (if it's used.).
var poolHandle *[]byte
useBufferPool := !mem.IsBelowBufferPoolingThreshold(int(fh.Length))
if useBufferPool {
poolHandle = f.pool.Get(int(fh.Length))
buf = *poolHandle
defer func() {
if err != nil {
f.pool.Put(poolHandle)
}
}()
} else {
buf = make([]byte, int(fh.Length))
}
if fh.Flags.Has(http2.FlagDataPadded) {
if fh.Length == 0 {
return io.ErrUnexpectedEOF
}
// This initial 1-byte read can be inefficient for unbuffered readers,
// but it allows the rest of the payload to be read directly to the
// start of the destination slice. This makes it easy to return the
// original slice back to the buffer pool.
if _, err := io.ReadFull(f.reader, buf[:1]); err != nil {
return err
}
padSize := buf[0]
buf = buf[:len(buf)-1]
if int(padSize) > len(buf) {
// If the length of the padding is greater than the
// length of the frame payload, the recipient MUST
// treat this as a connection error.
// Filed: https://github.com/http2/http2-spec/issues/610
f.errDetail = errors.New("pad size larger than data payload")
return http2.ConnectionError(http2.ErrCodeProtocol)
}
if _, err := io.ReadFull(f.reader, buf); err != nil {
return err
}
buf = buf[:len(buf)-int(padSize)]
} else if _, err := io.ReadFull(f.reader, buf); err != nil {
return err
}
f.dataFrame.FrameHeader = fh
if useBufferPool {
// Update the handle to point to the (potentially re-sliced) buf.
*poolHandle = buf
f.dataFrame.data = mem.NewBuffer(poolHandle, f.pool)
} else {
f.dataFrame.data = mem.SliceBuffer(buf)
}
return nil
}
func (df *parsedDataFrame) Header() http2.FrameHeader {
return df.FrameHeader
}
func getWriteBufferPool(size int) *sync.Pool { func getWriteBufferPool(size int) *sync.Pool {
writeBufferMutex.Lock() writeBufferMutex.Lock()
defer writeBufferMutex.Unlock() defer writeBufferMutex.Unlock()

View File

@@ -32,7 +32,7 @@ import (
// ServerStream implements streaming functionality for a gRPC server. // ServerStream implements streaming functionality for a gRPC server.
type ServerStream struct { type ServerStream struct {
*Stream // Embed for common stream functionality. Stream // Embed for common stream functionality.
st internalServerTransport st internalServerTransport
ctxDone <-chan struct{} // closed at the end of stream. Cache of ctx.Done() (for performance) ctxDone <-chan struct{} // closed at the end of stream. Cache of ctx.Done() (for performance)
@@ -43,12 +43,13 @@ type ServerStream struct {
// Holds compressor names passed in grpc-accept-encoding metadata from the // Holds compressor names passed in grpc-accept-encoding metadata from the
// client. // client.
clientAdvertisedCompressors string clientAdvertisedCompressors string
headerWireLength int
// hdrMu protects outgoing header and trailer metadata. // hdrMu protects outgoing header and trailer metadata.
hdrMu sync.Mutex hdrMu sync.Mutex
header metadata.MD // the outgoing header metadata. Updated by WriteHeader. header metadata.MD // the outgoing header metadata. Updated by WriteHeader.
headerSent atomic.Bool // atomically set when the headers are sent out. headerSent atomic.Bool // atomically set when the headers are sent out.
headerWireLength int
} }
// Read reads an n byte message from the input stream. // Read reads an n byte message from the input stream.
@@ -178,3 +179,11 @@ func (s *ServerStream) SetTrailer(md metadata.MD) error {
s.hdrMu.Unlock() s.hdrMu.Unlock()
return nil return nil
} }
func (s *ServerStream) requestRead(n int) {
s.st.adjustWindow(s, uint32(n))
}
func (s *ServerStream) updateWindow(n int) {
s.st.updateWindow(s, uint32(n))
}

View File

@@ -68,11 +68,11 @@ type recvBuffer struct {
err error err error
} }
func newRecvBuffer() *recvBuffer { // init allows a recvBuffer to be initialized in-place, which is useful
b := &recvBuffer{ // for resetting a buffer or for avoiding a heap allocation when the buffer
c: make(chan recvMsg, 1), // is embedded in another struct.
} func (b *recvBuffer) init() {
return b b.c = make(chan recvMsg, 1)
} }
func (b *recvBuffer) put(r recvMsg) { func (b *recvBuffer) put(r recvMsg) {
@@ -123,7 +123,8 @@ func (b *recvBuffer) get() <-chan recvMsg {
// recvBufferReader implements io.Reader interface to read the data from // recvBufferReader implements io.Reader interface to read the data from
// recvBuffer. // recvBuffer.
type recvBufferReader struct { type recvBufferReader struct {
closeStream func(error) // Closes the client transport stream with the given error and nil trailer metadata. _ noCopy
clientStream *ClientStream // The client transport stream is closed with a status representing ctx.Err() and nil trailer metadata.
ctx context.Context ctx context.Context
ctxDone <-chan struct{} // cache of ctx.Done() (for performance). ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
recv *recvBuffer recv *recvBuffer
@@ -139,7 +140,7 @@ func (r *recvBufferReader) ReadMessageHeader(header []byte) (n int, err error) {
n, r.last = mem.ReadUnsafe(header, r.last) n, r.last = mem.ReadUnsafe(header, r.last)
return n, nil return n, nil
} }
if r.closeStream != nil { if r.clientStream != nil {
n, r.err = r.readMessageHeaderClient(header) n, r.err = r.readMessageHeaderClient(header)
} else { } else {
n, r.err = r.readMessageHeader(header) n, r.err = r.readMessageHeader(header)
@@ -164,7 +165,7 @@ func (r *recvBufferReader) Read(n int) (buf mem.Buffer, err error) {
} }
return buf, nil return buf, nil
} }
if r.closeStream != nil { if r.clientStream != nil {
buf, r.err = r.readClient(n) buf, r.err = r.readClient(n)
} else { } else {
buf, r.err = r.read(n) buf, r.err = r.read(n)
@@ -209,7 +210,7 @@ func (r *recvBufferReader) readMessageHeaderClient(header []byte) (n int, err er
// TODO: delaying ctx error seems like a unnecessary side effect. What // TODO: delaying ctx error seems like a unnecessary side effect. What
// we really want is to mark the stream as done, and return ctx error // we really want is to mark the stream as done, and return ctx error
// faster. // faster.
r.closeStream(ContextErr(r.ctx.Err())) r.clientStream.Close(ContextErr(r.ctx.Err()))
m := <-r.recv.get() m := <-r.recv.get()
return r.readMessageHeaderAdditional(m, header) return r.readMessageHeaderAdditional(m, header)
case m := <-r.recv.get(): case m := <-r.recv.get():
@@ -236,7 +237,7 @@ func (r *recvBufferReader) readClient(n int) (buf mem.Buffer, err error) {
// TODO: delaying ctx error seems like a unnecessary side effect. What // TODO: delaying ctx error seems like a unnecessary side effect. What
// we really want is to mark the stream as done, and return ctx error // we really want is to mark the stream as done, and return ctx error
// faster. // faster.
r.closeStream(ContextErr(r.ctx.Err())) r.clientStream.Close(ContextErr(r.ctx.Err()))
m := <-r.recv.get() m := <-r.recv.get()
return r.readAdditional(m, n) return r.readAdditional(m, n)
case m := <-r.recv.get(): case m := <-r.recv.get():
@@ -285,27 +286,32 @@ const (
// Stream represents an RPC in the transport layer. // Stream represents an RPC in the transport layer.
type Stream struct { type Stream struct {
id uint32
ctx context.Context // the associated context of the stream ctx context.Context // the associated context of the stream
method string // the associated RPC method of the stream method string // the associated RPC method of the stream
recvCompress string recvCompress string
sendCompress string sendCompress string
buf *recvBuffer
trReader *transportReader
fc *inFlow
wq *writeQuota
// Callback to state application's intentions to read data. This readRequester readRequester
// is used to adjust flow control, if needed.
requestRead func(int)
state streamState
// contentSubtype is the content-subtype for requests. // contentSubtype is the content-subtype for requests.
// this must be lowercase or the behavior is undefined. // this must be lowercase or the behavior is undefined.
contentSubtype string contentSubtype string
trailer metadata.MD // the key-value map of trailer metadata. trailer metadata.MD // the key-value map of trailer metadata.
// Non-pointer fields are at the end to optimize GC performance.
state streamState
id uint32
buf recvBuffer
trReader transportReader
fc inFlow
wq writeQuota
}
// readRequester is used to state application's intentions to read data. This
// is used to adjust flow control, if needed.
type readRequester interface {
requestRead(int)
} }
func (s *Stream) swapState(st streamState) streamState { func (s *Stream) swapState(st streamState) streamState {
@@ -355,7 +361,7 @@ func (s *Stream) ReadMessageHeader(header []byte) (err error) {
if er := s.trReader.er; er != nil { if er := s.trReader.er; er != nil {
return er return er
} }
s.requestRead(len(header)) s.readRequester.requestRead(len(header))
for len(header) != 0 { for len(header) != 0 {
n, err := s.trReader.ReadMessageHeader(header) n, err := s.trReader.ReadMessageHeader(header)
header = header[n:] header = header[n:]
@@ -372,13 +378,29 @@ func (s *Stream) ReadMessageHeader(header []byte) (err error) {
return nil return nil
} }
// ceil returns the ceil after dividing the numerator and denominator while
// avoiding integer overflows.
func ceil(numerator, denominator int) int {
if numerator == 0 {
return 0
}
return (numerator-1)/denominator + 1
}
// Read reads n bytes from the wire for this stream. // Read reads n bytes from the wire for this stream.
func (s *Stream) read(n int) (data mem.BufferSlice, err error) { func (s *Stream) read(n int) (data mem.BufferSlice, err error) {
// Don't request a read if there was an error earlier // Don't request a read if there was an error earlier
if er := s.trReader.er; er != nil { if er := s.trReader.er; er != nil {
return nil, er return nil, er
} }
s.requestRead(n) // gRPC Go accepts data frames with a maximum length of 16KB. Larger
// messages must be split into multiple frames. We pre-allocate the
// buffer to avoid resizing during the read loop, but cap the initial
// capacity to 128 frames (2MB) to prevent over-allocation or panics
// when reading extremely large streams.
allocCap := min(ceil(n, http2MaxFrameLen), 128)
data = make(mem.BufferSlice, 0, allocCap)
s.readRequester.requestRead(n)
for n != 0 { for n != 0 {
buf, err := s.trReader.Read(n) buf, err := s.trReader.Read(n)
var bufLen int var bufLen int
@@ -401,16 +423,34 @@ func (s *Stream) read(n int) (data mem.BufferSlice, err error) {
return data, nil return data, nil
} }
// noCopy may be embedded into structs which must not be copied
// after the first use.
//
// See https://golang.org/issues/8005#issuecomment-190753527
// for details.
type noCopy struct {
}
func (*noCopy) Lock() {}
func (*noCopy) Unlock() {}
// transportReader reads all the data available for this Stream from the transport and // transportReader reads all the data available for this Stream from the transport and
// passes them into the decoder, which converts them into a gRPC message stream. // passes them into the decoder, which converts them into a gRPC message stream.
// The error is io.EOF when the stream is done or another non-nil error if // The error is io.EOF when the stream is done or another non-nil error if
// the stream broke. // the stream broke.
type transportReader struct { type transportReader struct {
reader *recvBufferReader _ noCopy
// The handler to control the window update procedure for both this // The handler to control the window update procedure for both this
// particular stream and the associated transport. // particular stream and the associated transport.
windowHandler func(int) windowHandler windowHandler
er error er error
reader recvBufferReader
}
// The handler to control the window update procedure for both this
// particular stream and the associated transport.
type windowHandler interface {
updateWindow(int)
} }
func (t *transportReader) ReadMessageHeader(header []byte) (int, error) { func (t *transportReader) ReadMessageHeader(header []byte) (int, error) {
@@ -419,7 +459,7 @@ func (t *transportReader) ReadMessageHeader(header []byte) (int, error) {
t.er = err t.er = err
return 0, err return 0, err
} }
t.windowHandler(n) t.windowHandler.updateWindow(n)
return n, nil return n, nil
} }
@@ -429,7 +469,7 @@ func (t *transportReader) Read(n int) (mem.Buffer, error) {
t.er = err t.er = err
return buf, err return buf, err
} }
t.windowHandler(buf.Len()) t.windowHandler.updateWindow(buf.Len())
return buf, nil return buf, nil
} }
@@ -454,7 +494,7 @@ type ServerConfig struct {
ConnectionTimeout time.Duration ConnectionTimeout time.Duration
Credentials credentials.TransportCredentials Credentials credentials.TransportCredentials
InTapHandle tap.ServerInHandle InTapHandle tap.ServerInHandle
StatsHandlers []stats.Handler StatsHandler stats.Handler
KeepaliveParams keepalive.ServerParameters KeepaliveParams keepalive.ServerParameters
KeepalivePolicy keepalive.EnforcementPolicy KeepalivePolicy keepalive.EnforcementPolicy
InitialWindowSize int32 InitialWindowSize int32
@@ -529,6 +569,12 @@ type CallHdr struct {
// outbound message. // outbound message.
SendCompress string SendCompress string
// AcceptedCompressors overrides the grpc-accept-encoding header for this
// call. When nil, the transport advertises the default set of registered
// compressors. A non-nil pointer overrides that value (including the empty
// string to advertise none).
AcceptedCompressors *string
// Creds specifies credentials.PerRPCCredentials for a call. // Creds specifies credentials.PerRPCCredentials for a call.
Creds credentials.PerRPCCredentials Creds credentials.PerRPCCredentials
@@ -544,9 +590,14 @@ type CallHdr struct {
DoneFunc func() // called when the stream is finished DoneFunc func() // called when the stream is finished
// Authority is used to explicitly override the `:authority` header. If set, // Authority is used to explicitly override the `:authority` header.
// this value takes precedence over the Host field and will be used as the //
// value for the `:authority` header. // This value comes from one of two sources:
// 1. The `CallAuthority` call option, if specified by the user.
// 2. An override provided by the LB picker (e.g. xDS authority rewriting).
//
// The `CallAuthority` call option always takes precedence over the LB
// picker override.
Authority string Authority string
} }
@@ -566,7 +617,7 @@ type ClientTransport interface {
GracefulClose() GracefulClose()
// NewStream creates a Stream for an RPC. // NewStream creates a Stream for an RPC.
NewStream(ctx context.Context, callHdr *CallHdr) (*ClientStream, error) NewStream(ctx context.Context, callHdr *CallHdr, handler stats.Handler) (*ClientStream, error)
// Error returns a channel that is closed when some I/O error // Error returns a channel that is closed when some I/O error
// happens. Typically the caller should have a goroutine to monitor // happens. Typically the caller should have a goroutine to monitor
@@ -584,8 +635,9 @@ type ClientTransport interface {
// with a human readable string with debug info. // with a human readable string with debug info.
GetGoAwayReason() (GoAwayReason, string) GetGoAwayReason() (GoAwayReason, string)
// RemoteAddr returns the remote network address. // Peer returns information about the peer associated with the Transport.
RemoteAddr() net.Addr // The returned information includes authentication and network address details.
Peer() *peer.Peer
} }
// ServerTransport is the common interface for all gRPC server-side transport // ServerTransport is the common interface for all gRPC server-side transport
@@ -615,6 +667,8 @@ type internalServerTransport interface {
write(s *ServerStream, hdr []byte, data mem.BufferSlice, opts *WriteOptions) error write(s *ServerStream, hdr []byte, data mem.BufferSlice, opts *WriteOptions) error
writeStatus(s *ServerStream, st *status.Status) error writeStatus(s *ServerStream, st *status.Status) error
incrMsgRecv() incrMsgRecv()
adjustWindow(s *ServerStream, n uint32)
updateWindow(s *ServerStream, n uint32)
} }
// connectionErrorf creates an ConnectionError with the specified error description. // connectionErrorf creates an ConnectionError with the specified error description.

View File

@@ -32,12 +32,17 @@ type BufferPool interface {
Get(length int) *[]byte Get(length int) *[]byte
// Put returns a buffer to the pool. // Put returns a buffer to the pool.
//
// The provided pointer must hold a prefix of the buffer obtained via
// BufferPool.Get to ensure the buffer's entire capacity can be re-used.
Put(*[]byte) Put(*[]byte)
} }
const goPageSize = 4 << 10 // 4KiB. N.B. this must be a power of 2.
var defaultBufferPoolSizes = []int{ var defaultBufferPoolSizes = []int{
256, 256,
4 << 10, // 4KB (go page size) goPageSize,
16 << 10, // 16KB (max HTTP/2 frame size used by gRPC) 16 << 10, // 16KB (max HTTP/2 frame size used by gRPC)
32 << 10, // 32KB (default buffer size for io.Copy) 32 << 10, // 32KB (default buffer size for io.Copy)
1 << 20, // 1MB 1 << 20, // 1MB
@@ -48,7 +53,7 @@ var defaultBufferPool BufferPool
func init() { func init() {
defaultBufferPool = NewTieredBufferPool(defaultBufferPoolSizes...) defaultBufferPool = NewTieredBufferPool(defaultBufferPoolSizes...)
internal.SetDefaultBufferPoolForTesting = func(pool BufferPool) { internal.SetDefaultBufferPool = func(pool BufferPool) {
defaultBufferPool = pool defaultBufferPool = pool
} }
@@ -118,7 +123,11 @@ type sizedBufferPool struct {
} }
func (p *sizedBufferPool) Get(size int) *[]byte { func (p *sizedBufferPool) Get(size int) *[]byte {
buf := p.pool.Get().(*[]byte) buf, ok := p.pool.Get().(*[]byte)
if !ok {
buf := make([]byte, size, p.defaultSize)
return &buf
}
b := *buf b := *buf
clear(b[:cap(b)]) clear(b[:cap(b)])
*buf = b[:size] *buf = b[:size]
@@ -137,12 +146,6 @@ func (p *sizedBufferPool) Put(buf *[]byte) {
func newSizedBufferPool(size int) *sizedBufferPool { func newSizedBufferPool(size int) *sizedBufferPool {
return &sizedBufferPool{ return &sizedBufferPool{
pool: sync.Pool{
New: func() any {
buf := make([]byte, size)
return &buf
},
},
defaultSize: size, defaultSize: size,
} }
} }
@@ -160,6 +163,7 @@ type simpleBufferPool struct {
func (p *simpleBufferPool) Get(size int) *[]byte { func (p *simpleBufferPool) Get(size int) *[]byte {
bs, ok := p.pool.Get().(*[]byte) bs, ok := p.pool.Get().(*[]byte)
if ok && cap(*bs) >= size { if ok && cap(*bs) >= size {
clear((*bs)[:cap(*bs)])
*bs = (*bs)[:size] *bs = (*bs)[:size]
return bs return bs
} }
@@ -170,7 +174,14 @@ func (p *simpleBufferPool) Get(size int) *[]byte {
p.pool.Put(bs) p.pool.Put(bs)
} }
b := make([]byte, size) // If we're going to allocate, round up to the nearest page. This way if
// requests frequently arrive with small variation we don't allocate
// repeatedly if we get unlucky and they increase over time. By default we
// only allocate here if size > 1MiB. Because goPageSize is a power of 2, we
// can round up efficiently.
allocSize := (size + goPageSize - 1) & ^(goPageSize - 1)
b := make([]byte, size, allocSize)
return &b return &b
} }

View File

@@ -19,6 +19,7 @@
package mem package mem
import ( import (
"fmt"
"io" "io"
) )
@@ -117,43 +118,36 @@ func (s BufferSlice) MaterializeToBuffer(pool BufferPool) Buffer {
// Reader returns a new Reader for the input slice after taking references to // Reader returns a new Reader for the input slice after taking references to
// each underlying buffer. // each underlying buffer.
func (s BufferSlice) Reader() Reader { func (s BufferSlice) Reader() *Reader {
s.Ref() s.Ref()
return &sliceReader{ return &Reader{
data: s, data: s,
len: s.Len(), len: s.Len(),
} }
} }
// Reader exposes a BufferSlice's data as an io.Reader, allowing it to interface // Reader exposes a BufferSlice's data as an io.Reader, allowing it to interface
// with other parts systems. It also provides an additional convenience method // with other systems.
// Remaining(), which returns the number of unread bytes remaining in the slice. //
// Buffers will be freed as they are read. // Buffers will be freed as they are read.
type Reader interface { //
io.Reader // A Reader can be constructed from a BufferSlice; alternatively the zero value
io.ByteReader // of a Reader may be used after calling Reset on it.
// Close frees the underlying BufferSlice and never returns an error. Subsequent type Reader struct {
// calls to Read will return (0, io.EOF).
Close() error
// Remaining returns the number of unread bytes remaining in the slice.
Remaining() int
// Reset frees the currently held buffer slice and starts reading from the
// provided slice. This allows reusing the reader object.
Reset(s BufferSlice)
}
type sliceReader struct {
data BufferSlice data BufferSlice
len int len int
// The index into data[0].ReadOnlyData(). // The index into data[0].ReadOnlyData().
bufferIdx int bufferIdx int
} }
func (r *sliceReader) Remaining() int { // Remaining returns the number of unread bytes remaining in the slice.
func (r *Reader) Remaining() int {
return r.len return r.len
} }
func (r *sliceReader) Reset(s BufferSlice) { // Reset frees the currently held buffer slice and starts reading from the
// provided slice. This allows reusing the reader object.
func (r *Reader) Reset(s BufferSlice) {
r.data.Free() r.data.Free()
s.Ref() s.Ref()
r.data = s r.data = s
@@ -161,14 +155,16 @@ func (r *sliceReader) Reset(s BufferSlice) {
r.bufferIdx = 0 r.bufferIdx = 0
} }
func (r *sliceReader) Close() error { // Close frees the underlying BufferSlice and never returns an error. Subsequent
// calls to Read will return (0, io.EOF).
func (r *Reader) Close() error {
r.data.Free() r.data.Free()
r.data = nil r.data = nil
r.len = 0 r.len = 0
return nil return nil
} }
func (r *sliceReader) freeFirstBufferIfEmpty() bool { func (r *Reader) freeFirstBufferIfEmpty() bool {
if len(r.data) == 0 || r.bufferIdx != len(r.data[0].ReadOnlyData()) { if len(r.data) == 0 || r.bufferIdx != len(r.data[0].ReadOnlyData()) {
return false return false
} }
@@ -179,7 +175,7 @@ func (r *sliceReader) freeFirstBufferIfEmpty() bool {
return true return true
} }
func (r *sliceReader) Read(buf []byte) (n int, _ error) { func (r *Reader) Read(buf []byte) (n int, _ error) {
if r.len == 0 { if r.len == 0 {
return 0, io.EOF return 0, io.EOF
} }
@@ -202,7 +198,8 @@ func (r *sliceReader) Read(buf []byte) (n int, _ error) {
return n, nil return n, nil
} }
func (r *sliceReader) ReadByte() (byte, error) { // ReadByte reads a single byte.
func (r *Reader) ReadByte() (byte, error) {
if r.len == 0 { if r.len == 0 {
return 0, io.EOF return 0, io.EOF
} }
@@ -290,3 +287,59 @@ nextBuffer:
} }
} }
} }
// Discard skips the next n bytes, returning the number of bytes discarded.
//
// It frees buffers as they are fully consumed.
//
// If Discard skips fewer than n bytes, it also returns an error.
func (r *Reader) Discard(n int) (discarded int, err error) {
total := n
for n > 0 && r.len > 0 {
curData := r.data[0].ReadOnlyData()
curSize := min(n, len(curData)-r.bufferIdx)
n -= curSize
r.len -= curSize
r.bufferIdx += curSize
if r.bufferIdx >= len(curData) {
r.data[0].Free()
r.data = r.data[1:]
r.bufferIdx = 0
}
}
discarded = total - n
if n > 0 {
return discarded, fmt.Errorf("insufficient bytes in reader")
}
return discarded, nil
}
// Peek returns the next n bytes without advancing the reader.
//
// Peek appends results to the provided res slice and returns the updated slice.
// This pattern allows re-using the storage of res if it has sufficient
// capacity.
//
// The returned subslices are views into the underlying buffers and are only
// valid until the reader is advanced past the corresponding buffer.
//
// If Peek returns fewer than n bytes, it also returns an error.
func (r *Reader) Peek(n int, res [][]byte) ([][]byte, error) {
for i := 0; n > 0 && i < len(r.data); i++ {
curData := r.data[i].ReadOnlyData()
start := 0
if i == 0 {
start = r.bufferIdx
}
curSize := min(n, len(curData)-start)
if curSize == 0 {
continue
}
res = append(res, curData[start:start+curSize])
n -= curSize
}
if n > 0 {
return nil, fmt.Errorf("insufficient bytes in reader")
}
return res, nil
}

View File

@@ -62,7 +62,6 @@ var (
bufferPoolingThreshold = 1 << 10 bufferPoolingThreshold = 1 << 10
bufferObjectPool = sync.Pool{New: func() any { return new(buffer) }} bufferObjectPool = sync.Pool{New: func() any { return new(buffer) }}
refObjectPool = sync.Pool{New: func() any { return new(atomic.Int32) }}
) )
// IsBelowBufferPoolingThreshold returns true if the given size is less than or // IsBelowBufferPoolingThreshold returns true if the given size is less than or
@@ -73,9 +72,19 @@ func IsBelowBufferPoolingThreshold(size int) bool {
} }
type buffer struct { type buffer struct {
origData *[]byte refs atomic.Int32
data []byte data []byte
refs *atomic.Int32
// rootBuf is the buffer responsible for returning origData to the pool
// once the reference count drops to 0.
//
// When a buffer is split, the new buffer inherits the rootBuf of the
// original and increments the root's reference count. For the
// initial buffer (the root), this field points to itself.
rootBuf *buffer
// The following fields are only set for root buffers.
origData *[]byte
pool BufferPool pool BufferPool
} }
@@ -103,8 +112,8 @@ func NewBuffer(data *[]byte, pool BufferPool) Buffer {
b.origData = data b.origData = data
b.data = *data b.data = *data
b.pool = pool b.pool = pool
b.refs = refObjectPool.Get().(*atomic.Int32) b.rootBuf = b
b.refs.Add(1) b.refs.Store(1)
return b return b
} }
@@ -127,42 +136,44 @@ func Copy(data []byte, pool BufferPool) Buffer {
} }
func (b *buffer) ReadOnlyData() []byte { func (b *buffer) ReadOnlyData() []byte {
if b.refs == nil { if b.rootBuf == nil {
panic("Cannot read freed buffer") panic("Cannot read freed buffer")
} }
return b.data return b.data
} }
func (b *buffer) Ref() { func (b *buffer) Ref() {
if b.refs == nil { if b.refs.Add(1) <= 1 {
panic("Cannot ref freed buffer") panic("Cannot ref freed buffer")
} }
b.refs.Add(1)
} }
func (b *buffer) Free() { func (b *buffer) Free() {
if b.refs == nil { refs := b.refs.Add(-1)
if refs < 0 {
panic("Cannot free freed buffer") panic("Cannot free freed buffer")
} }
if refs > 0 {
refs := b.refs.Add(-1)
switch {
case refs > 0:
return return
case refs == 0: }
b.data = nil
if b.rootBuf == b {
// This buffer is the owner of the data slice and its ref count reached
// 0, free the slice.
if b.pool != nil { if b.pool != nil {
b.pool.Put(b.origData) b.pool.Put(b.origData)
b.pool = nil
}
b.origData = nil
} else {
// This buffer doesn't own the data slice, decrement a ref on the root
// buffer.
b.rootBuf.Free()
} }
refObjectPool.Put(b.refs) b.rootBuf = nil
b.origData = nil
b.data = nil
b.refs = nil
b.pool = nil
bufferObjectPool.Put(b) bufferObjectPool.Put(b)
default:
panic("Cannot free freed buffer")
}
} }
func (b *buffer) Len() int { func (b *buffer) Len() int {
@@ -170,16 +181,14 @@ func (b *buffer) Len() int {
} }
func (b *buffer) split(n int) (Buffer, Buffer) { func (b *buffer) split(n int) (Buffer, Buffer) {
if b.refs == nil { if b.rootBuf == nil || b.rootBuf.refs.Add(1) <= 1 {
panic("Cannot split freed buffer") panic("Cannot split freed buffer")
} }
b.refs.Add(1)
split := newBuffer() split := newBuffer()
split.origData = b.origData
split.data = b.data[n:] split.data = b.data[n:]
split.refs = b.refs split.rootBuf = b.rootBuf
split.pool = b.pool split.refs.Store(1)
b.data = b.data[:n] b.data = b.data[:n]
@@ -187,7 +196,7 @@ func (b *buffer) split(n int) (Buffer, Buffer) {
} }
func (b *buffer) read(buf []byte) (int, Buffer) { func (b *buffer) read(buf []byte) (int, Buffer) {
if b.refs == nil { if b.rootBuf == nil {
panic("Cannot read freed buffer") panic("Cannot read freed buffer")
} }

View File

@@ -29,7 +29,6 @@ import (
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
istatus "google.golang.org/grpc/internal/status" istatus "google.golang.org/grpc/internal/status"
"google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
@@ -49,13 +48,10 @@ type pickerGeneration struct {
type pickerWrapper struct { type pickerWrapper struct {
// If pickerGen holds a nil pointer, the pickerWrapper is closed. // If pickerGen holds a nil pointer, the pickerWrapper is closed.
pickerGen atomic.Pointer[pickerGeneration] pickerGen atomic.Pointer[pickerGeneration]
statsHandlers []stats.Handler // to record blocking picker calls
} }
func newPickerWrapper(statsHandlers []stats.Handler) *pickerWrapper { func newPickerWrapper() *pickerWrapper {
pw := &pickerWrapper{ pw := &pickerWrapper{}
statsHandlers: statsHandlers,
}
pw.pickerGen.Store(&pickerGeneration{ pw.pickerGen.Store(&pickerGeneration{
blockingCh: make(chan struct{}), blockingCh: make(chan struct{}),
}) })
@@ -93,6 +89,12 @@ func doneChannelzWrapper(acbw *acBalancerWrapper, result *balancer.PickResult) {
} }
} }
type pick struct {
transport transport.ClientTransport // the selected transport
result balancer.PickResult // the contents of the pick from the LB policy
blocked bool // set if a picker call queued for a new picker
}
// pick returns the transport that will be used for the RPC. // pick returns the transport that will be used for the RPC.
// It may block in the following cases: // It may block in the following cases:
// - there's no picker // - there's no picker
@@ -100,15 +102,16 @@ func doneChannelzWrapper(acbw *acBalancerWrapper, result *balancer.PickResult) {
// - the current picker returns other errors and failfast is false. // - the current picker returns other errors and failfast is false.
// - the subConn returned by the current picker is not READY // - the subConn returned by the current picker is not READY
// When one of these situations happens, pick blocks until the picker gets updated. // When one of these situations happens, pick blocks until the picker gets updated.
func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.PickInfo) (transport.ClientTransport, balancer.PickResult, error) { func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.PickInfo) (pick, error) {
var ch chan struct{} var ch chan struct{}
var lastPickErr error var lastPickErr error
pickBlocked := false
for { for {
pg := pw.pickerGen.Load() pg := pw.pickerGen.Load()
if pg == nil { if pg == nil {
return nil, balancer.PickResult{}, ErrClientConnClosing return pick{}, ErrClientConnClosing
} }
if pg.picker == nil { if pg.picker == nil {
ch = pg.blockingCh ch = pg.blockingCh
@@ -127,9 +130,9 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
} }
switch ctx.Err() { switch ctx.Err() {
case context.DeadlineExceeded: case context.DeadlineExceeded:
return nil, balancer.PickResult{}, status.Error(codes.DeadlineExceeded, errStr) return pick{}, status.Error(codes.DeadlineExceeded, errStr)
case context.Canceled: case context.Canceled:
return nil, balancer.PickResult{}, status.Error(codes.Canceled, errStr) return pick{}, status.Error(codes.Canceled, errStr)
} }
case <-ch: case <-ch:
} }
@@ -145,9 +148,7 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
// In the second case, the only way it will get to this conditional is // In the second case, the only way it will get to this conditional is
// if there is a new picker. // if there is a new picker.
if ch != nil { if ch != nil {
for _, sh := range pw.statsHandlers { pickBlocked = true
sh.HandleRPC(ctx, &stats.PickerUpdated{})
}
} }
ch = pg.blockingCh ch = pg.blockingCh
@@ -164,7 +165,7 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
if istatus.IsRestrictedControlPlaneCode(st) { if istatus.IsRestrictedControlPlaneCode(st) {
err = status.Errorf(codes.Internal, "received picker error with illegal status: %v", err) err = status.Errorf(codes.Internal, "received picker error with illegal status: %v", err)
} }
return nil, balancer.PickResult{}, dropError{error: err} return pick{}, dropError{error: err}
} }
// For all other errors, wait for ready RPCs should block and other // For all other errors, wait for ready RPCs should block and other
// RPCs should fail with unavailable. // RPCs should fail with unavailable.
@@ -172,7 +173,7 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
lastPickErr = err lastPickErr = err
continue continue
} }
return nil, balancer.PickResult{}, status.Error(codes.Unavailable, err.Error()) return pick{}, status.Error(codes.Unavailable, err.Error())
} }
acbw, ok := pickResult.SubConn.(*acBalancerWrapper) acbw, ok := pickResult.SubConn.(*acBalancerWrapper)
@@ -183,9 +184,8 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
if t := acbw.ac.getReadyTransport(); t != nil { if t := acbw.ac.getReadyTransport(); t != nil {
if channelz.IsOn() { if channelz.IsOn() {
doneChannelzWrapper(acbw, &pickResult) doneChannelzWrapper(acbw, &pickResult)
return t, pickResult, nil
} }
return t, pickResult, nil return pick{transport: t, result: pickResult, blocked: pickBlocked}, nil
} }
if pickResult.Done != nil { if pickResult.Done != nil {
// Calling done with nil error, no bytes sent and no bytes received. // Calling done with nil error, no bytes sent and no bytes received.

View File

@@ -47,9 +47,6 @@ func (p *PreparedMsg) Encode(s Stream, msg any) error {
} }
// check if the context has the relevant information to prepareMsg // check if the context has the relevant information to prepareMsg
if rpcInfo.preloaderInfo == nil {
return status.Errorf(codes.Internal, "grpc: rpcInfo.preloaderInfo is nil")
}
if rpcInfo.preloaderInfo.codec == nil { if rpcInfo.preloaderInfo.codec == nil {
return status.Errorf(codes.Internal, "grpc: rpcInfo.preloaderInfo.codec is nil") return status.Errorf(codes.Internal, "grpc: rpcInfo.preloaderInfo.codec is nil")
} }

View File

@@ -182,6 +182,7 @@ type BuildOptions struct {
// An Endpoint is one network endpoint, or server, which may have multiple // An Endpoint is one network endpoint, or server, which may have multiple
// addresses with which it can be accessed. // addresses with which it can be accessed.
// TODO(i/8773) : make resolver.Endpoint and resolver.Address immutable
type Endpoint struct { type Endpoint struct {
// Addresses contains a list of addresses used to access this endpoint. // Addresses contains a list of addresses used to access this endpoint.
Addresses []Address Addresses []Address
@@ -332,6 +333,11 @@ type AuthorityOverrider interface {
// OverrideAuthority returns the authority to use for a ClientConn with the // OverrideAuthority returns the authority to use for a ClientConn with the
// given target. The implementation must generate it without blocking, // given target. The implementation must generate it without blocking,
// typically in line, and must keep it unchanged. // typically in line, and must keep it unchanged.
//
// The returned string must be a valid ":authority" header value, i.e. be
// encoded according to
// [RFC3986](https://datatracker.ietf.org/doc/html/rfc3986#section-3.2) as
// necessary.
OverrideAuthority(Target) string OverrideAuthority(Target) string
} }

View File

@@ -69,6 +69,7 @@ func (ccr *ccResolverWrapper) start() error {
errCh := make(chan error) errCh := make(chan error)
ccr.serializer.TrySchedule(func(ctx context.Context) { ccr.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil { if ctx.Err() != nil {
errCh <- ctx.Err()
return return
} }
opts := resolver.BuildOptions{ opts := resolver.BuildOptions{

View File

@@ -33,6 +33,8 @@ import (
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/proto" "google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/mem" "google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@@ -41,6 +43,10 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
func init() {
internal.AcceptCompressors = acceptCompressors
}
// Compressor defines the interface gRPC uses to compress a message. // Compressor defines the interface gRPC uses to compress a message.
// //
// Deprecated: use package encoding. // Deprecated: use package encoding.
@@ -161,6 +167,22 @@ type callInfo struct {
maxRetryRPCBufferSize int maxRetryRPCBufferSize int
onFinish []func(err error) onFinish []func(err error)
authority string authority string
acceptedResponseCompressors []string
}
func acceptedCompressorAllows(allowed []string, name string) bool {
if allowed == nil {
return true
}
if name == "" || name == encoding.Identity {
return true
}
for _, a := range allowed {
if a == name {
return true
}
}
return false
} }
func defaultCallInfo() *callInfo { func defaultCallInfo() *callInfo {
@@ -170,6 +192,29 @@ func defaultCallInfo() *callInfo {
} }
} }
func newAcceptedCompressionConfig(names []string) ([]string, error) {
if len(names) == 0 {
return nil, nil
}
var allowed []string
seen := make(map[string]struct{}, len(names))
for _, name := range names {
name = strings.TrimSpace(name)
if name == "" || name == encoding.Identity {
continue
}
if !grpcutil.IsCompressorNameRegistered(name) {
return nil, status.Errorf(codes.InvalidArgument, "grpc: compressor %q is not registered", name)
}
if _, dup := seen[name]; dup {
continue
}
seen[name] = struct{}{}
allowed = append(allowed, name)
}
return allowed, nil
}
// CallOption configures a Call before it starts or extracts information from // CallOption configures a Call before it starts or extracts information from
// a Call after it completes. // a Call after it completes.
type CallOption interface { type CallOption interface {
@@ -471,6 +516,31 @@ func (o CompressorCallOption) before(c *callInfo) error {
} }
func (o CompressorCallOption) after(*callInfo, *csAttempt) {} func (o CompressorCallOption) after(*callInfo, *csAttempt) {}
// acceptCompressors returns a CallOption that limits the compression algorithms
// advertised in the grpc-accept-encoding header for response messages.
// Compression algorithms not in the provided list will not be advertised, and
// responses compressed with non-listed algorithms will be rejected.
func acceptCompressors(names ...string) CallOption {
cp := append([]string(nil), names...)
return acceptCompressorsCallOption{names: cp}
}
// acceptCompressorsCallOption is a CallOption that limits response compression.
type acceptCompressorsCallOption struct {
names []string
}
func (o acceptCompressorsCallOption) before(c *callInfo) error {
allowed, err := newAcceptedCompressionConfig(o.names)
if err != nil {
return err
}
c.acceptedResponseCompressors = allowed
return nil
}
func (acceptCompressorsCallOption) after(*callInfo, *csAttempt) {}
// CallContentSubtype returns a CallOption that will set the content-subtype // CallContentSubtype returns a CallOption that will set the content-subtype
// for a call. For example, if content-subtype is "json", the Content-Type over // for a call. For example, if content-subtype is "json", the Content-Type over
// the wire will be "application/grpc+json". The content-subtype is converted // the wire will be "application/grpc+json". The content-subtype is converted
@@ -657,8 +727,20 @@ type streamReader interface {
Read(n int) (mem.BufferSlice, error) Read(n int) (mem.BufferSlice, error)
} }
// noCopy may be embedded into structs which must not be copied
// after the first use.
//
// See https://golang.org/issues/8005#issuecomment-190753527
// for details.
type noCopy struct {
}
func (*noCopy) Lock() {}
func (*noCopy) Unlock() {}
// parser reads complete gRPC messages from the underlying reader. // parser reads complete gRPC messages from the underlying reader.
type parser struct { type parser struct {
_ noCopy
// r is the underlying reader. // r is the underlying reader.
// See the comment on recvMsg for the permissible // See the comment on recvMsg for the permissible
// error types. // error types.
@@ -845,8 +927,7 @@ func (p *payloadInfo) free() {
// the buffer is no longer needed. // the buffer is no longer needed.
// TODO: Refactor this function to reduce the number of arguments. // TODO: Refactor this function to reduce the number of arguments.
// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists // See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) (out mem.BufferSlice, err error) {
) (out mem.BufferSlice, err error) {
pf, compressed, err := p.recvMsg(maxReceiveMessageSize) pf, compressed, err := p.recvMsg(maxReceiveMessageSize)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -949,7 +1030,7 @@ func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxR
// Information about RPC // Information about RPC
type rpcInfo struct { type rpcInfo struct {
failfast bool failfast bool
preloaderInfo *compressorInfo preloaderInfo compressorInfo
} }
// Information about Preloader // Information about Preloader
@@ -968,7 +1049,7 @@ type rpcInfoContextKey struct{}
func newContextWithRPCInfo(ctx context.Context, failfast bool, codec baseCodec, cp Compressor, comp encoding.Compressor) context.Context { func newContextWithRPCInfo(ctx context.Context, failfast bool, codec baseCodec, cp Compressor, comp encoding.Compressor) context.Context {
return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{ return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{
failfast: failfast, failfast: failfast,
preloaderInfo: &compressorInfo{ preloaderInfo: compressorInfo{
codec: codec, codec: codec,
cp: cp, cp: cp,
comp: comp, comp: comp,

View File

@@ -42,6 +42,7 @@ import (
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/binarylog"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/grpcutil"
istats "google.golang.org/grpc/internal/stats" istats "google.golang.org/grpc/internal/stats"
@@ -125,6 +126,7 @@ type serviceInfo struct {
// Server is a gRPC server to serve RPC requests. // Server is a gRPC server to serve RPC requests.
type Server struct { type Server struct {
opts serverOptions opts serverOptions
statsHandler stats.Handler
mu sync.Mutex // guards following mu sync.Mutex // guards following
lis map[net.Listener]bool lis map[net.Listener]bool
@@ -148,6 +150,8 @@ type Server struct {
serverWorkerChannel chan func() serverWorkerChannel chan func()
serverWorkerChannelClose func() serverWorkerChannelClose func()
strictPathCheckingLogEmitted atomic.Bool
} }
type serverOptions struct { type serverOptions struct {
@@ -694,6 +698,7 @@ func NewServer(opt ...ServerOption) *Server {
s := &Server{ s := &Server{
lis: make(map[net.Listener]bool), lis: make(map[net.Listener]bool),
opts: opts, opts: opts,
statsHandler: istats.NewCombinedHandler(opts.statsHandlers...),
conns: make(map[string]map[transport.ServerTransport]bool), conns: make(map[string]map[transport.ServerTransport]bool),
services: make(map[string]*serviceInfo), services: make(map[string]*serviceInfo),
quit: grpcsync.NewEvent(), quit: grpcsync.NewEvent(),
@@ -921,9 +926,7 @@ func (s *Server) Serve(lis net.Listener) error {
tempDelay = 5 * time.Millisecond tempDelay = 5 * time.Millisecond
} else { } else {
tempDelay *= 2 tempDelay *= 2
} tempDelay = min(tempDelay, 1*time.Second)
if max := 1 * time.Second; tempDelay > max {
tempDelay = max
} }
s.mu.Lock() s.mu.Lock()
s.printf("Accept error: %v; retrying in %v", err, tempDelay) s.printf("Accept error: %v; retrying in %v", err, tempDelay)
@@ -999,7 +1002,7 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
ConnectionTimeout: s.opts.connectionTimeout, ConnectionTimeout: s.opts.connectionTimeout,
Credentials: s.opts.creds, Credentials: s.opts.creds,
InTapHandle: s.opts.inTapHandle, InTapHandle: s.opts.inTapHandle,
StatsHandlers: s.opts.statsHandlers, StatsHandler: s.statsHandler,
KeepaliveParams: s.opts.keepaliveParams, KeepaliveParams: s.opts.keepaliveParams,
KeepalivePolicy: s.opts.keepalivePolicy, KeepalivePolicy: s.opts.keepalivePolicy,
InitialWindowSize: s.opts.initialWindowSize, InitialWindowSize: s.opts.initialWindowSize,
@@ -1036,18 +1039,18 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, rawConn net.Conn) { func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, rawConn net.Conn) {
ctx = transport.SetConnection(ctx, rawConn) ctx = transport.SetConnection(ctx, rawConn)
ctx = peer.NewContext(ctx, st.Peer()) ctx = peer.NewContext(ctx, st.Peer())
for _, sh := range s.opts.statsHandlers { if s.statsHandler != nil {
ctx = sh.TagConn(ctx, &stats.ConnTagInfo{ ctx = s.statsHandler.TagConn(ctx, &stats.ConnTagInfo{
RemoteAddr: st.Peer().Addr, RemoteAddr: st.Peer().Addr,
LocalAddr: st.Peer().LocalAddr, LocalAddr: st.Peer().LocalAddr,
}) })
sh.HandleConn(ctx, &stats.ConnBegin{}) s.statsHandler.HandleConn(ctx, &stats.ConnBegin{})
} }
defer func() { defer func() {
st.Close(errors.New("finished serving streams for the server transport")) st.Close(errors.New("finished serving streams for the server transport"))
for _, sh := range s.opts.statsHandlers { if s.statsHandler != nil {
sh.HandleConn(ctx, &stats.ConnEnd{}) s.statsHandler.HandleConn(ctx, &stats.ConnEnd{})
} }
}() }()
@@ -1104,7 +1107,7 @@ var _ http.Handler = (*Server)(nil)
// Notice: This API is EXPERIMENTAL and may be changed or removed in a // Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release. // later release.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers, s.opts.bufferPool) st, err := transport.NewServerHandlerTransport(w, r, s.statsHandler, s.opts.bufferPool)
if err != nil { if err != nil {
// Errors returned from transport.NewServerHandlerTransport have // Errors returned from transport.NewServerHandlerTransport have
// already been written to w. // already been written to w.
@@ -1198,12 +1201,8 @@ func (s *Server) sendResponse(ctx context.Context, stream *transport.ServerStrea
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", payloadLen, s.opts.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", payloadLen, s.opts.maxSendMessageSize)
} }
err = stream.Write(hdr, payload, opts) err = stream.Write(hdr, payload, opts)
if err == nil { if err == nil && s.statsHandler != nil {
if len(s.opts.statsHandlers) != 0 { s.statsHandler.HandleRPC(ctx, outPayload(false, msg, dataLen, payloadLen, time.Now()))
for _, sh := range s.opts.statsHandlers {
sh.HandleRPC(ctx, outPayload(false, msg, dataLen, payloadLen, time.Now()))
}
}
} }
return err return err
} }
@@ -1245,16 +1244,15 @@ func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info
} }
func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerStream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) { func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerStream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) {
shs := s.opts.statsHandlers sh := s.statsHandler
if len(shs) != 0 || trInfo != nil || channelz.IsOn() { if sh != nil || trInfo != nil || channelz.IsOn() {
if channelz.IsOn() { if channelz.IsOn() {
s.incrCallsStarted() s.incrCallsStarted()
} }
var statsBegin *stats.Begin var statsBegin *stats.Begin
for _, sh := range shs { if sh != nil {
beginTime := time.Now()
statsBegin = &stats.Begin{ statsBegin = &stats.Begin{
BeginTime: beginTime, BeginTime: time.Now(),
IsClientStream: false, IsClientStream: false,
IsServerStream: false, IsServerStream: false,
} }
@@ -1282,7 +1280,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt
trInfo.tr.Finish() trInfo.tr.Finish()
} }
for _, sh := range shs { if sh != nil {
end := &stats.End{ end := &stats.End{
BeginTime: statsBegin.BeginTime, BeginTime: statsBegin.BeginTime,
EndTime: time.Now(), EndTime: time.Now(),
@@ -1379,7 +1377,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt
} }
var payInfo *payloadInfo var payInfo *payloadInfo
if len(shs) != 0 || len(binlogs) != 0 { if sh != nil || len(binlogs) != 0 {
payInfo = &payloadInfo{} payInfo = &payloadInfo{}
defer payInfo.free() defer payInfo.free()
} }
@@ -1405,7 +1403,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
} }
for _, sh := range shs { if sh != nil {
sh.HandleRPC(ctx, &stats.InPayload{ sh.HandleRPC(ctx, &stats.InPayload{
RecvTime: time.Now(), RecvTime: time.Now(),
Payload: v, Payload: v,
@@ -1579,32 +1577,30 @@ func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.Serv
if channelz.IsOn() { if channelz.IsOn() {
s.incrCallsStarted() s.incrCallsStarted()
} }
shs := s.opts.statsHandlers sh := s.statsHandler
var statsBegin *stats.Begin var statsBegin *stats.Begin
if len(shs) != 0 { if sh != nil {
beginTime := time.Now()
statsBegin = &stats.Begin{ statsBegin = &stats.Begin{
BeginTime: beginTime, BeginTime: time.Now(),
IsClientStream: sd.ClientStreams, IsClientStream: sd.ClientStreams,
IsServerStream: sd.ServerStreams, IsServerStream: sd.ServerStreams,
} }
for _, sh := range shs {
sh.HandleRPC(ctx, statsBegin) sh.HandleRPC(ctx, statsBegin)
} }
}
ctx = NewContextWithServerTransportStream(ctx, stream) ctx = NewContextWithServerTransportStream(ctx, stream)
ss := &serverStream{ ss := &serverStream{
ctx: ctx, ctx: ctx,
s: stream, s: stream,
p: &parser{r: stream, bufferPool: s.opts.bufferPool}, p: parser{r: stream, bufferPool: s.opts.bufferPool},
codec: s.getCodec(stream.ContentSubtype()), codec: s.getCodec(stream.ContentSubtype()),
desc: sd,
maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
maxSendMessageSize: s.opts.maxSendMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize,
trInfo: trInfo, trInfo: trInfo,
statsHandler: shs, statsHandler: sh,
} }
if len(shs) != 0 || trInfo != nil || channelz.IsOn() { if sh != nil || trInfo != nil || channelz.IsOn() {
// See comment in processUnaryRPC on defers. // See comment in processUnaryRPC on defers.
defer func() { defer func() {
if trInfo != nil { if trInfo != nil {
@@ -1618,7 +1614,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.Serv
ss.mu.Unlock() ss.mu.Unlock()
} }
if len(shs) != 0 { if sh != nil {
end := &stats.End{ end := &stats.End{
BeginTime: statsBegin.BeginTime, BeginTime: statsBegin.BeginTime,
EndTime: time.Now(), EndTime: time.Now(),
@@ -1626,10 +1622,8 @@ func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.Serv
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
end.Error = toRPCErr(err) end.Error = toRPCErr(err)
} }
for _, sh := range shs {
sh.HandleRPC(ctx, end) sh.HandleRPC(ctx, end)
} }
}
if channelz.IsOn() { if channelz.IsOn() {
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
@@ -1771,6 +1765,24 @@ func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.Serv
return ss.s.WriteStatus(statusOK) return ss.s.WriteStatus(statusOK)
} }
func (s *Server) handleMalformedMethodName(stream *transport.ServerStream, ti *traceInfo) {
if ti != nil {
ti.tr.LazyLog(&fmtStringer{"Malformed method name %q", []any{stream.Method()}}, true)
ti.tr.SetError()
}
errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
if err := stream.WriteStatus(status.New(codes.Unimplemented, errDesc)); err != nil {
if ti != nil {
ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
ti.tr.SetError()
}
channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err)
}
if ti != nil {
ti.tr.Finish()
}
}
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.ServerStream) { func (s *Server) handleStream(t transport.ServerTransport, stream *transport.ServerStream) {
ctx := stream.Context() ctx := stream.Context()
ctx = contextWithServer(ctx, s) ctx = contextWithServer(ctx, s)
@@ -1791,37 +1803,40 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Ser
} }
sm := stream.Method() sm := stream.Method()
if sm != "" && sm[0] == '/' { if sm == "" {
s.handleMalformedMethodName(stream, ti)
return
}
if sm[0] != '/' {
// TODO(easwars): Add a link to the CVE in the below log messages once
// published.
if envconfig.DisableStrictPathChecking {
if old := s.strictPathCheckingLogEmitted.Swap(true); !old {
channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream received malformed method name %q. Allowing it because the environment variable GRPC_GO_EXPERIMENTAL_DISABLE_STRICT_PATH_CHECKING is set to true, but this option will be removed in a future release.", sm)
}
} else {
if old := s.strictPathCheckingLogEmitted.Swap(true); !old {
channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream rejected malformed method name %q. To temporarily allow such requests, set the environment variable GRPC_GO_EXPERIMENTAL_DISABLE_STRICT_PATH_CHECKING to true. Note that this is not recommended as it may allow requests to bypass security policies.", sm)
}
s.handleMalformedMethodName(stream, ti)
return
}
} else {
sm = sm[1:] sm = sm[1:]
} }
pos := strings.LastIndex(sm, "/") pos := strings.LastIndex(sm, "/")
if pos == -1 { if pos == -1 {
if ti != nil { s.handleMalformedMethodName(stream, ti)
ti.tr.LazyLog(&fmtStringer{"Malformed method name %q", []any{sm}}, true)
ti.tr.SetError()
}
errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
if err := stream.WriteStatus(status.New(codes.Unimplemented, errDesc)); err != nil {
if ti != nil {
ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
ti.tr.SetError()
}
channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err)
}
if ti != nil {
ti.tr.Finish()
}
return return
} }
service := sm[:pos] service := sm[:pos]
method := sm[pos+1:] method := sm[pos+1:]
// FromIncomingContext is expensive: skip if there are no statsHandlers // FromIncomingContext is expensive: skip if there are no statsHandlers
if len(s.opts.statsHandlers) > 0 { if s.statsHandler != nil {
md, _ := metadata.FromIncomingContext(ctx) md, _ := metadata.FromIncomingContext(ctx)
for _, sh := range s.opts.statsHandlers { ctx = s.statsHandler.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()})
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()}) s.statsHandler.HandleRPC(ctx, &stats.InHeader{
sh.HandleRPC(ctx, &stats.InHeader{
FullMethod: stream.Method(), FullMethod: stream.Method(),
RemoteAddr: t.Peer().Addr, RemoteAddr: t.Peer().Addr,
LocalAddr: t.Peer().LocalAddr, LocalAddr: t.Peer().LocalAddr,
@@ -1830,7 +1845,6 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Ser
Header: md, Header: md,
}) })
} }
}
// To have calls in stream callouts work. Will delete once all stats handler // To have calls in stream callouts work. Will delete once all stats handler
// calls come from the gRPC layer. // calls come from the gRPC layer.
stream.SetContext(ctx) stream.SetContext(ctx)

View File

@@ -64,15 +64,21 @@ func (s *Begin) IsClient() bool { return s.Client }
func (s *Begin) isRPCStats() {} func (s *Begin) isRPCStats() {}
// PickerUpdated indicates that the LB policy provided a new picker while the // DelayedPickComplete indicates that the RPC is unblocked following a delay in
// RPC was waiting for one. // selecting a connection for the call.
type PickerUpdated struct{} type DelayedPickComplete struct{}
// IsClient indicates if the stats information is from client side. Only Client // IsClient indicates DelayedPickComplete is available on the client.
// Side interfaces with a Picker, thus always returns true. func (*DelayedPickComplete) IsClient() bool { return true }
func (*PickerUpdated) IsClient() bool { return true }
func (*PickerUpdated) isRPCStats() {} func (*DelayedPickComplete) isRPCStats() {}
// PickerUpdated indicates that the RPC is unblocked following a delay in
// selecting a connection for the call.
//
// Deprecated: will be removed in a future release; use DelayedPickComplete
// instead.
type PickerUpdated = DelayedPickComplete
// InPayload contains stats about an incoming payload. // InPayload contains stats about an incoming payload.
type InPayload struct { type InPayload struct {

View File

@@ -25,6 +25,7 @@ import (
"math" "math"
rand "math/rand/v2" rand "math/rand/v2"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@@ -51,7 +52,8 @@ import (
var metadataFromOutgoingContextRaw = internal.FromOutgoingContextRaw.(func(context.Context) (metadata.MD, [][]string, bool)) var metadataFromOutgoingContextRaw = internal.FromOutgoingContextRaw.(func(context.Context) (metadata.MD, [][]string, bool))
// StreamHandler defines the handler called by gRPC server to complete the // StreamHandler defines the handler called by gRPC server to complete the
// execution of a streaming RPC. // execution of a streaming RPC. srv is the service implementation on which the
// RPC was invoked.
// //
// If a StreamHandler returns an error, it should either be produced by the // If a StreamHandler returns an error, it should either be produced by the
// status package, or be one of the context errors. Otherwise, gRPC will use // status package, or be one of the context errors. Otherwise, gRPC will use
@@ -177,13 +179,43 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
return cc.NewStream(ctx, desc, method, opts...) return cc.NewStream(ctx, desc, method, opts...)
} }
var emptyMethodConfig = serviceconfig.MethodConfig{}
// endOfClientStream performs cleanup actions required for both successful and
// failed streams. This includes incrementing channelz stats and invoking all
// registered OnFinish call options.
func endOfClientStream(cc *ClientConn, err error, opts ...CallOption) {
if channelz.IsOn() {
if err != nil {
cc.incrCallsFailed()
} else {
cc.incrCallsSucceeded()
}
}
for _, o := range opts {
if o, ok := o.(OnFinishCallOption); ok {
o.OnFinish(err)
}
}
}
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
if channelz.IsOn() {
cc.incrCallsStarted()
}
defer func() {
if err != nil {
// Ensure cleanup when stream creation fails.
endOfClientStream(cc, err, opts...)
}
}()
// Start tracking the RPC for idleness purposes. This is where a stream is // Start tracking the RPC for idleness purposes. This is where a stream is
// created for both streaming and unary RPCs, and hence is a good place to // created for both streaming and unary RPCs, and hence is a good place to
// track active RPC count. // track active RPC count.
if err := cc.idlenessMgr.OnCallBegin(); err != nil { cc.idlenessMgr.OnCallBegin()
return nil, err
}
// Add a calloption, to decrement the active call count, that gets executed // Add a calloption, to decrement the active call count, that gets executed
// when the RPC completes. // when the RPC completes.
opts = append([]CallOption{OnFinish(func(error) { cc.idlenessMgr.OnCallEnd() })}, opts...) opts = append([]CallOption{OnFinish(func(error) { cc.idlenessMgr.OnCallEnd() })}, opts...)
@@ -202,14 +234,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
} }
} }
} }
if channelz.IsOn() {
cc.incrCallsStarted()
defer func() {
if err != nil {
cc.incrCallsFailed()
}
}()
}
// Provide an opportunity for the first RPC to see the first service config // Provide an opportunity for the first RPC to see the first service config
// provided by the resolver. // provided by the resolver.
nameResolutionDelayed, err := cc.waitForResolvedAddrs(ctx) nameResolutionDelayed, err := cc.waitForResolvedAddrs(ctx)
@@ -217,7 +241,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
return nil, err return nil, err
} }
var mc serviceconfig.MethodConfig mc := &emptyMethodConfig
var onCommit func() var onCommit func()
newStream := func(ctx context.Context, done func()) (iresolver.ClientStream, error) { newStream := func(ctx context.Context, done func()) (iresolver.ClientStream, error) {
return newClientStreamWithParams(ctx, desc, cc, method, mc, onCommit, done, nameResolutionDelayed, opts...) return newClientStreamWithParams(ctx, desc, cc, method, mc, onCommit, done, nameResolutionDelayed, opts...)
@@ -240,7 +264,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if rpcConfig.Context != nil { if rpcConfig.Context != nil {
ctx = rpcConfig.Context ctx = rpcConfig.Context
} }
mc = rpcConfig.MethodConfig mc = &rpcConfig.MethodConfig
onCommit = rpcConfig.OnCommitted onCommit = rpcConfig.OnCommitted
if rpcConfig.Interceptor != nil { if rpcConfig.Interceptor != nil {
rpcInfo.Context = nil rpcInfo.Context = nil
@@ -258,7 +282,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
return newStream(ctx, func() {}) return newStream(ctx, func() {})
} }
func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, mc serviceconfig.MethodConfig, onCommit, doneFunc func(), nameResolutionDelayed bool, opts ...CallOption) (_ iresolver.ClientStream, err error) { func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, mc *serviceconfig.MethodConfig, onCommit, doneFunc func(), nameResolutionDelayed bool, opts ...CallOption) (_ iresolver.ClientStream, err error) {
callInfo := defaultCallInfo() callInfo := defaultCallInfo()
if mc.WaitForReady != nil { if mc.WaitForReady != nil {
callInfo.failFast = !*mc.WaitForReady callInfo.failFast = !*mc.WaitForReady
@@ -299,6 +323,10 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client
DoneFunc: doneFunc, DoneFunc: doneFunc,
Authority: callInfo.authority, Authority: callInfo.authority,
} }
if allowed := callInfo.acceptedResponseCompressors; len(allowed) > 0 {
headerValue := strings.Join(allowed, ",")
callHdr.AcceptedCompressors = &headerValue
}
// Set our outgoing compression according to the UseCompressor CallOption, if // Set our outgoing compression according to the UseCompressor CallOption, if
// set. In that case, also find the compressor from the encoding package. // set. In that case, also find the compressor from the encoding package.
@@ -325,7 +353,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client
cs := &clientStream{ cs := &clientStream{
callHdr: callHdr, callHdr: callHdr,
ctx: ctx, ctx: ctx,
methodConfig: &mc, methodConfig: mc,
opts: opts, opts: opts,
callInfo: callInfo, callInfo: callInfo,
cc: cc, cc: cc,
@@ -418,19 +446,21 @@ func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error)
ctx := newContextWithRPCInfo(cs.ctx, cs.callInfo.failFast, cs.callInfo.codec, cs.compressorV0, cs.compressorV1) ctx := newContextWithRPCInfo(cs.ctx, cs.callInfo.failFast, cs.callInfo.codec, cs.compressorV0, cs.compressorV1)
method := cs.callHdr.Method method := cs.callHdr.Method
var beginTime time.Time var beginTime time.Time
shs := cs.cc.dopts.copts.StatsHandlers sh := cs.cc.statsHandler
for _, sh := range shs { if sh != nil {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: cs.callInfo.failFast, NameResolutionDelay: cs.nameResolutionDelay})
beginTime = time.Now() beginTime = time.Now()
begin := &stats.Begin{ ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{
FullMethodName: method, FailFast: cs.callInfo.failFast,
NameResolutionDelay: cs.nameResolutionDelay,
})
sh.HandleRPC(ctx, &stats.Begin{
Client: true, Client: true,
BeginTime: beginTime, BeginTime: beginTime,
FailFast: cs.callInfo.failFast, FailFast: cs.callInfo.failFast,
IsClientStream: cs.desc.ClientStreams, IsClientStream: cs.desc.ClientStreams,
IsServerStream: cs.desc.ServerStreams, IsServerStream: cs.desc.ServerStreams,
IsTransparentRetryAttempt: isTransparent, IsTransparentRetryAttempt: isTransparent,
} })
sh.HandleRPC(ctx, begin)
} }
var trInfo *traceInfo var trInfo *traceInfo
@@ -461,7 +491,7 @@ func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error)
beginTime: beginTime, beginTime: beginTime,
cs: cs, cs: cs,
decompressorV0: cs.cc.dopts.dc, decompressorV0: cs.cc.dopts.dc,
statsHandlers: shs, statsHandler: sh,
trInfo: trInfo, trInfo: trInfo,
}, nil }, nil
} }
@@ -469,8 +499,9 @@ func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error)
func (a *csAttempt) getTransport() error { func (a *csAttempt) getTransport() error {
cs := a.cs cs := a.cs
var err error pickInfo := balancer.PickInfo{Ctx: a.ctx, FullMethodName: cs.callHdr.Method}
a.transport, a.pickResult, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method) pick, err := cs.cc.pickerWrapper.pick(a.ctx, cs.callInfo.failFast, pickInfo)
a.transport, a.pickResult = pick.transport, pick.result
if err != nil { if err != nil {
if de, ok := err.(dropError); ok { if de, ok := err.(dropError); ok {
err = de.error err = de.error
@@ -479,7 +510,10 @@ func (a *csAttempt) getTransport() error {
return err return err
} }
if a.trInfo != nil { if a.trInfo != nil {
a.trInfo.firstLine.SetRemoteAddr(a.transport.RemoteAddr()) a.trInfo.firstLine.SetRemoteAddr(a.transport.Peer().Addr)
}
if pick.blocked && a.statsHandler != nil {
a.statsHandler.HandleRPC(a.ctx, &stats.DelayedPickComplete{})
} }
return nil return nil
} }
@@ -504,9 +538,17 @@ func (a *csAttempt) newStream() error {
md, _ := metadata.FromOutgoingContext(a.ctx) md, _ := metadata.FromOutgoingContext(a.ctx)
md = metadata.Join(md, a.pickResult.Metadata) md = metadata.Join(md, a.pickResult.Metadata)
a.ctx = metadata.NewOutgoingContext(a.ctx, md) a.ctx = metadata.NewOutgoingContext(a.ctx, md)
}
s, err := a.transport.NewStream(a.ctx, cs.callHdr) // If the `CallAuthority` CallOption is not set, check if the LB picker
// has provided an authority override in the PickResult metadata and
// apply it, as specified in gRFC A81.
if cs.callInfo.authority == "" {
if authMD := a.pickResult.Metadata.Get(":authority"); len(authMD) > 0 {
cs.callHdr.Authority = authMD[0]
}
}
}
s, err := a.transport.NewStream(a.ctx, cs.callHdr, a.statsHandler)
if err != nil { if err != nil {
nse, ok := err.(*transport.NewStreamError) nse, ok := err.(*transport.NewStreamError)
if !ok { if !ok {
@@ -523,7 +565,7 @@ func (a *csAttempt) newStream() error {
} }
a.transportStream = s a.transportStream = s
a.ctx = s.Context() a.ctx = s.Context()
a.parser = &parser{r: s, bufferPool: a.cs.cc.dopts.copts.BufferPool} a.parser = parser{r: s, bufferPool: a.cs.cc.dopts.copts.BufferPool}
return nil return nil
} }
@@ -543,6 +585,8 @@ type clientStream struct {
sentLast bool // sent an end stream sentLast bool // sent an end stream
receivedFirstMsg bool // set after the first message is received
methodConfig *MethodConfig methodConfig *MethodConfig
ctx context.Context // the application's context, wrapped by stats/tracing ctx context.Context // the application's context, wrapped by stats/tracing
@@ -593,7 +637,7 @@ type csAttempt struct {
cs *clientStream cs *clientStream
transport transport.ClientTransport transport transport.ClientTransport
transportStream *transport.ClientStream transportStream *transport.ClientStream
parser *parser parser parser
pickResult balancer.PickResult pickResult balancer.PickResult
finished bool finished bool
@@ -607,7 +651,7 @@ type csAttempt struct {
// and cleared when the finish method is called. // and cleared when the finish method is called.
trInfo *traceInfo trInfo *traceInfo
statsHandlers []stats.Handler statsHandler stats.Handler
beginTime time.Time beginTime time.Time
// set for newStream errors that may be transparently retried // set for newStream errors that may be transparently retried
@@ -1032,9 +1076,6 @@ func (cs *clientStream) finish(err error) {
return return
} }
cs.finished = true cs.finished = true
for _, onFinish := range cs.callInfo.onFinish {
onFinish(err)
}
cs.commitAttemptLocked() cs.commitAttemptLocked()
if cs.attempt != nil { if cs.attempt != nil {
cs.attempt.finish(err) cs.attempt.finish(err)
@@ -1074,13 +1115,7 @@ func (cs *clientStream) finish(err error) {
if err == nil { if err == nil {
cs.retryThrottler.successfulRPC() cs.retryThrottler.successfulRPC()
} }
if channelz.IsOn() { endOfClientStream(cs.cc, err, cs.opts...)
if err != nil {
cs.cc.incrCallsFailed()
} else {
cs.cc.incrCallsSucceeded()
}
}
cs.cancel() cs.cancel()
} }
@@ -1102,17 +1137,15 @@ func (a *csAttempt) sendMsg(m any, hdr []byte, payld mem.BufferSlice, dataLength
} }
return io.EOF return io.EOF
} }
if len(a.statsHandlers) != 0 { if a.statsHandler != nil {
for _, sh := range a.statsHandlers { a.statsHandler.HandleRPC(a.ctx, outPayload(true, m, dataLength, payloadLength, time.Now()))
sh.HandleRPC(a.ctx, outPayload(true, m, dataLength, payloadLength, time.Now()))
}
} }
return nil return nil
} }
func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
cs := a.cs cs := a.cs
if len(a.statsHandlers) != 0 && payInfo == nil { if a.statsHandler != nil && payInfo == nil {
payInfo = &payloadInfo{} payInfo = &payloadInfo{}
defer payInfo.free() defer payInfo.free()
} }
@@ -1126,6 +1159,10 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
a.decompressorV0 = nil a.decompressorV0 = nil
a.decompressorV1 = encoding.GetCompressor(ct) a.decompressorV1 = encoding.GetCompressor(ct)
} }
// Validate that the compression method is acceptable for this call.
if !acceptedCompressorAllows(cs.callInfo.acceptedResponseCompressors, ct) {
return status.Errorf(codes.Internal, "grpc: peer compressed the response with %q which is not allowed by AcceptCompressors", ct)
}
} else { } else {
// No compression is used; disable our decompressor. // No compression is used; disable our decompressor.
a.decompressorV0 = nil a.decompressorV0 = nil
@@ -1133,16 +1170,21 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
// Only initialize this state once per stream. // Only initialize this state once per stream.
a.decompressorSet = true a.decompressorSet = true
} }
if err := recv(a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decompressorV1, false); err != nil { if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decompressorV1, false); err != nil {
if err == io.EOF { if err == io.EOF {
if statusErr := a.transportStream.Status().Err(); statusErr != nil { if statusErr := a.transportStream.Status().Err(); statusErr != nil {
return statusErr return statusErr
} }
// Received no msg and status OK for non-server streaming rpcs.
if !cs.desc.ServerStreams && !cs.receivedFirstMsg {
return status.Error(codes.Internal, "cardinality violation: received no response message from non-server-streaming RPC")
}
return io.EOF // indicates successful end of stream. return io.EOF // indicates successful end of stream.
} }
return toRPCErr(err) return toRPCErr(err)
} }
cs.receivedFirstMsg = true
if a.trInfo != nil { if a.trInfo != nil {
a.mu.Lock() a.mu.Lock()
if a.trInfo.tr != nil { if a.trInfo.tr != nil {
@@ -1150,8 +1192,8 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
} }
a.mu.Unlock() a.mu.Unlock()
} }
for _, sh := range a.statsHandlers { if a.statsHandler != nil {
sh.HandleRPC(a.ctx, &stats.InPayload{ a.statsHandler.HandleRPC(a.ctx, &stats.InPayload{
Client: true, Client: true,
RecvTime: time.Now(), RecvTime: time.Now(),
Payload: m, Payload: m,
@@ -1166,12 +1208,12 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
} }
// Special handling for non-server-stream rpcs. // Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload. // This recv expects EOF or errors, so we don't collect inPayload.
if err := recv(a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decompressorV1, false); err == io.EOF { if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decompressorV1, false); err == io.EOF {
return a.transportStream.Status().Err() // non-server streaming Recv returns nil on success return a.transportStream.Status().Err() // non-server streaming Recv returns nil on success
} else if err != nil { } else if err != nil {
return toRPCErr(err) return toRPCErr(err)
} }
return status.Errorf(codes.Internal, "cardinality violation: expected <EOF> for non server-streaming RPCs, but received another message") return status.Error(codes.Internal, "cardinality violation: expected <EOF> for non server-streaming RPCs, but received another message")
} }
func (a *csAttempt) finish(err error) { func (a *csAttempt) finish(err error) {
@@ -1204,15 +1246,14 @@ func (a *csAttempt) finish(err error) {
ServerLoad: balancerload.Parse(tr), ServerLoad: balancerload.Parse(tr),
}) })
} }
for _, sh := range a.statsHandlers { if a.statsHandler != nil {
end := &stats.End{ a.statsHandler.HandleRPC(a.ctx, &stats.End{
Client: true, Client: true,
BeginTime: a.beginTime, BeginTime: a.beginTime,
EndTime: time.Now(), EndTime: time.Now(),
Trailer: tr, Trailer: tr,
Error: err, Error: err,
} })
sh.HandleRPC(a.ctx, end)
} }
if a.trInfo != nil && a.trInfo.tr != nil { if a.trInfo != nil && a.trInfo.tr != nil {
if err == nil { if err == nil {
@@ -1309,16 +1350,18 @@ func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method strin
codec: c.codec, codec: c.codec,
sendCompressorV0: cp, sendCompressorV0: cp,
sendCompressorV1: comp, sendCompressorV1: comp,
decompressorV0: ac.cc.dopts.dc,
transport: t, transport: t,
} }
s, err := as.transport.NewStream(as.ctx, as.callHdr) // nil stats handler: internal streams like health and ORCA do not support telemetry.
s, err := as.transport.NewStream(as.ctx, as.callHdr, nil)
if err != nil { if err != nil {
err = toRPCErr(err) err = toRPCErr(err)
return nil, err return nil, err
} }
as.transportStream = s as.transportStream = s
as.parser = &parser{r: s, bufferPool: ac.dopts.copts.BufferPool} as.parser = parser{r: s, bufferPool: ac.dopts.copts.BufferPool}
ac.incrCallsStarted() ac.incrCallsStarted()
if desc != unaryStreamDesc { if desc != unaryStreamDesc {
// Listen on stream context to cleanup when the stream context is // Listen on stream context to cleanup when the stream context is
@@ -1353,6 +1396,7 @@ type addrConnStream struct {
transport transport.ClientTransport transport transport.ClientTransport
ctx context.Context ctx context.Context
sentLast bool sentLast bool
receivedFirstMsg bool
desc *StreamDesc desc *StreamDesc
codec baseCodec codec baseCodec
sendCompressorV0 Compressor sendCompressorV0 Compressor
@@ -1360,7 +1404,7 @@ type addrConnStream struct {
decompressorSet bool decompressorSet bool
decompressorV0 Decompressor decompressorV0 Decompressor
decompressorV1 encoding.Compressor decompressorV1 encoding.Compressor
parser *parser parser parser
// mu guards finished and is held for the entire finish method. // mu guards finished and is held for the entire finish method.
mu sync.Mutex mu sync.Mutex
@@ -1466,6 +1510,10 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
as.decompressorV0 = nil as.decompressorV0 = nil
as.decompressorV1 = encoding.GetCompressor(ct) as.decompressorV1 = encoding.GetCompressor(ct)
} }
// Validate that the compression method is acceptable for this call.
if !acceptedCompressorAllows(as.callInfo.acceptedResponseCompressors, ct) {
return status.Errorf(codes.Internal, "grpc: peer compressed the response with %q which is not allowed by AcceptCompressors", ct)
}
} else { } else {
// No compression is used; disable our decompressor. // No compression is used; disable our decompressor.
as.decompressorV0 = nil as.decompressorV0 = nil
@@ -1473,15 +1521,20 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
// Only initialize this state once per stream. // Only initialize this state once per stream.
as.decompressorSet = true as.decompressorSet = true
} }
if err := recv(as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err != nil { if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err != nil {
if err == io.EOF { if err == io.EOF {
if statusErr := as.transportStream.Status().Err(); statusErr != nil { if statusErr := as.transportStream.Status().Err(); statusErr != nil {
return statusErr return statusErr
} }
// Received no msg and status OK for non-server streaming rpcs.
if !as.desc.ServerStreams && !as.receivedFirstMsg {
return status.Error(codes.Internal, "cardinality violation: received no response message from non-server-streaming RPC")
}
return io.EOF // indicates successful end of stream. return io.EOF // indicates successful end of stream.
} }
return toRPCErr(err) return toRPCErr(err)
} }
as.receivedFirstMsg = true
if as.desc.ServerStreams { if as.desc.ServerStreams {
// Subsequent messages should be received by subsequent RecvMsg calls. // Subsequent messages should be received by subsequent RecvMsg calls.
@@ -1490,12 +1543,12 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
// Special handling for non-server-stream rpcs. // Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload. // This recv expects EOF or errors, so we don't collect inPayload.
if err := recv(as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err == io.EOF { if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err == io.EOF {
return as.transportStream.Status().Err() // non-server streaming Recv returns nil on success return as.transportStream.Status().Err() // non-server streaming Recv returns nil on success
} else if err != nil { } else if err != nil {
return toRPCErr(err) return toRPCErr(err)
} }
return status.Errorf(codes.Internal, "cardinality violation: expected <EOF> for non server-streaming RPCs, but received another message") return status.Error(codes.Internal, "cardinality violation: expected <EOF> for non server-streaming RPCs, but received another message")
} }
func (as *addrConnStream) finish(err error) { func (as *addrConnStream) finish(err error) {
@@ -1578,8 +1631,9 @@ type ServerStream interface {
type serverStream struct { type serverStream struct {
ctx context.Context ctx context.Context
s *transport.ServerStream s *transport.ServerStream
p *parser p parser
codec baseCodec codec baseCodec
desc *StreamDesc
compressorV0 Compressor compressorV0 Compressor
compressorV1 encoding.Compressor compressorV1 encoding.Compressor
@@ -1588,11 +1642,13 @@ type serverStream struct {
sendCompressorName string sendCompressorName string
recvFirstMsg bool // set after the first message is received
maxReceiveMessageSize int maxReceiveMessageSize int
maxSendMessageSize int maxSendMessageSize int
trInfo *traceInfo trInfo *traceInfo
statsHandler []stats.Handler statsHandler stats.Handler
binlogs []binarylog.MethodLogger binlogs []binarylog.MethodLogger
// serverHeaderBinlogged indicates whether server header has been logged. It // serverHeaderBinlogged indicates whether server header has been logged. It
@@ -1728,10 +1784,8 @@ func (ss *serverStream) SendMsg(m any) (err error) {
binlog.Log(ss.ctx, sm) binlog.Log(ss.ctx, sm)
} }
} }
if len(ss.statsHandler) != 0 { if ss.statsHandler != nil {
for _, sh := range ss.statsHandler { ss.statsHandler.HandleRPC(ss.s.Context(), outPayload(false, m, dataLen, payloadLen, time.Now()))
sh.HandleRPC(ss.s.Context(), outPayload(false, m, dataLen, payloadLen, time.Now()))
}
} }
return nil return nil
} }
@@ -1762,11 +1816,11 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
} }
}() }()
var payInfo *payloadInfo var payInfo *payloadInfo
if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 { if ss.statsHandler != nil || len(ss.binlogs) != 0 {
payInfo = &payloadInfo{} payInfo = &payloadInfo{}
defer payInfo.free() defer payInfo.free()
} }
if err := recv(ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, payInfo, ss.decompressorV1, true); err != nil { if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, payInfo, ss.decompressorV1, true); err != nil {
if err == io.EOF { if err == io.EOF {
if len(ss.binlogs) != 0 { if len(ss.binlogs) != 0 {
chc := &binarylog.ClientHalfClose{} chc := &binarylog.ClientHalfClose{}
@@ -1774,6 +1828,10 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
binlog.Log(ss.ctx, chc) binlog.Log(ss.ctx, chc)
} }
} }
// Received no request msg for non-client streaming rpcs.
if !ss.desc.ClientStreams && !ss.recvFirstMsg {
return status.Error(codes.Internal, "cardinality violation: received no request message from non-client-streaming RPC")
}
return err return err
} }
if err == io.ErrUnexpectedEOF { if err == io.ErrUnexpectedEOF {
@@ -1781,9 +1839,9 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
} }
return toRPCErr(err) return toRPCErr(err)
} }
if len(ss.statsHandler) != 0 { ss.recvFirstMsg = true
for _, sh := range ss.statsHandler { if ss.statsHandler != nil {
sh.HandleRPC(ss.s.Context(), &stats.InPayload{ ss.statsHandler.HandleRPC(ss.s.Context(), &stats.InPayload{
RecvTime: time.Now(), RecvTime: time.Now(),
Payload: m, Payload: m,
Length: payInfo.uncompressedBytes.Len(), Length: payInfo.uncompressedBytes.Len(),
@@ -1791,7 +1849,6 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
CompressedLength: payInfo.compressedLength, CompressedLength: payInfo.compressedLength,
}) })
} }
}
if len(ss.binlogs) != 0 { if len(ss.binlogs) != 0 {
cm := &binarylog.ClientMessage{ cm := &binarylog.ClientMessage{
Message: payInfo.uncompressedBytes.Materialize(), Message: payInfo.uncompressedBytes.Materialize(),
@@ -1800,8 +1857,20 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
binlog.Log(ss.ctx, cm) binlog.Log(ss.ctx, cm)
} }
} }
if ss.desc.ClientStreams {
// Subsequent messages should be received by subsequent RecvMsg calls.
return nil return nil
} }
// Special handling for non-client-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload.
if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true); err == io.EOF {
return nil
} else if err != nil {
return err
}
return status.Error(codes.Internal, "cardinality violation: received multiple request messages for non-client-streaming RPC")
}
// MethodFromServerStream returns the method string for the input stream. // MethodFromServerStream returns the method string for the input stream.
// The returned string is in the format of "/service/method". // The returned string is in the format of "/service/method".

View File

@@ -19,4 +19,4 @@
package grpc package grpc
// Version is the current grpc version. // Version is the current grpc version.
const Version = "1.74.2" const Version = "1.79.3"

View File

@@ -74,6 +74,7 @@ type (
FileL2 struct { FileL2 struct {
Options func() protoreflect.ProtoMessage Options func() protoreflect.ProtoMessage
Imports FileImports Imports FileImports
OptionImports func() protoreflect.FileImports
Locations SourceLocations Locations SourceLocations
} }
@@ -126,9 +127,6 @@ func (fd *File) ParentFile() protoreflect.FileDescriptor { return fd }
func (fd *File) Parent() protoreflect.Descriptor { return nil } func (fd *File) Parent() protoreflect.Descriptor { return nil }
func (fd *File) Index() int { return 0 } func (fd *File) Index() int { return 0 }
func (fd *File) Syntax() protoreflect.Syntax { return fd.L1.Syntax } func (fd *File) Syntax() protoreflect.Syntax { return fd.L1.Syntax }
// Not exported and just used to reconstruct the original FileDescriptor proto
func (fd *File) Edition() int32 { return int32(fd.L1.Edition) }
func (fd *File) Name() protoreflect.Name { return fd.L1.Package.Name() } func (fd *File) Name() protoreflect.Name { return fd.L1.Package.Name() }
func (fd *File) FullName() protoreflect.FullName { return fd.L1.Package } func (fd *File) FullName() protoreflect.FullName { return fd.L1.Package }
func (fd *File) IsPlaceholder() bool { return false } func (fd *File) IsPlaceholder() bool { return false }
@@ -150,6 +148,16 @@ func (fd *File) Format(s fmt.State, r rune) { descfmt.FormatD
func (fd *File) ProtoType(protoreflect.FileDescriptor) {} func (fd *File) ProtoType(protoreflect.FileDescriptor) {}
func (fd *File) ProtoInternal(pragma.DoNotImplement) {} func (fd *File) ProtoInternal(pragma.DoNotImplement) {}
// The next two are not part of the FileDescriptor interface. They are just used to reconstruct
// the original FileDescriptor proto.
func (fd *File) Edition() int32 { return int32(fd.L1.Edition) }
func (fd *File) OptionImports() protoreflect.FileImports {
if f := fd.lazyInit().OptionImports; f != nil {
return f()
}
return emptyFiles
}
func (fd *File) lazyInit() *FileL2 { func (fd *File) lazyInit() *FileL2 {
if atomic.LoadUint32(&fd.once) == 0 { if atomic.LoadUint32(&fd.once) == 0 {
fd.lazyInitOnce() fd.lazyInitOnce()
@@ -182,9 +190,9 @@ type (
L2 *EnumL2 // protected by fileDesc.once L2 *EnumL2 // protected by fileDesc.once
} }
EnumL1 struct { EnumL1 struct {
eagerValues bool // controls whether EnumL2.Values is already populated
EditionFeatures EditionFeatures EditionFeatures EditionFeatures
Visibility int32
eagerValues bool // controls whether EnumL2.Values is already populated
} }
EnumL2 struct { EnumL2 struct {
Options func() protoreflect.ProtoMessage Options func() protoreflect.ProtoMessage
@@ -219,6 +227,11 @@ func (ed *Enum) ReservedNames() protoreflect.Names { return &ed.lazyInit()
func (ed *Enum) ReservedRanges() protoreflect.EnumRanges { return &ed.lazyInit().ReservedRanges } func (ed *Enum) ReservedRanges() protoreflect.EnumRanges { return &ed.lazyInit().ReservedRanges }
func (ed *Enum) Format(s fmt.State, r rune) { descfmt.FormatDesc(s, r, ed) } func (ed *Enum) Format(s fmt.State, r rune) { descfmt.FormatDesc(s, r, ed) }
func (ed *Enum) ProtoType(protoreflect.EnumDescriptor) {} func (ed *Enum) ProtoType(protoreflect.EnumDescriptor) {}
// This is not part of the EnumDescriptor interface. It is just used to reconstruct
// the original FileDescriptor proto.
func (ed *Enum) Visibility() int32 { return ed.L1.Visibility }
func (ed *Enum) lazyInit() *EnumL2 { func (ed *Enum) lazyInit() *EnumL2 {
ed.L0.ParentFile.lazyInit() // implicitly initializes L2 ed.L0.ParentFile.lazyInit() // implicitly initializes L2
return ed.L2 return ed.L2
@@ -247,10 +260,10 @@ type (
Enums Enums Enums Enums
Messages Messages Messages Messages
Extensions Extensions Extensions Extensions
EditionFeatures EditionFeatures
Visibility int32
IsMapEntry bool // promoted from google.protobuf.MessageOptions IsMapEntry bool // promoted from google.protobuf.MessageOptions
IsMessageSet bool // promoted from google.protobuf.MessageOptions IsMessageSet bool // promoted from google.protobuf.MessageOptions
EditionFeatures EditionFeatures
} }
MessageL2 struct { MessageL2 struct {
Options func() protoreflect.ProtoMessage Options func() protoreflect.ProtoMessage
@@ -319,6 +332,11 @@ func (md *Message) Messages() protoreflect.MessageDescriptors { return &md.L
func (md *Message) Extensions() protoreflect.ExtensionDescriptors { return &md.L1.Extensions } func (md *Message) Extensions() protoreflect.ExtensionDescriptors { return &md.L1.Extensions }
func (md *Message) ProtoType(protoreflect.MessageDescriptor) {} func (md *Message) ProtoType(protoreflect.MessageDescriptor) {}
func (md *Message) Format(s fmt.State, r rune) { descfmt.FormatDesc(s, r, md) } func (md *Message) Format(s fmt.State, r rune) { descfmt.FormatDesc(s, r, md) }
// This is not part of the MessageDescriptor interface. It is just used to reconstruct
// the original FileDescriptor proto.
func (md *Message) Visibility() int32 { return md.L1.Visibility }
func (md *Message) lazyInit() *MessageL2 { func (md *Message) lazyInit() *MessageL2 {
md.L0.ParentFile.lazyInit() // implicitly initializes L2 md.L0.ParentFile.lazyInit() // implicitly initializes L2
return md.L2 return md.L2

View File

@@ -284,6 +284,13 @@ func (ed *Enum) unmarshalSeed(b []byte, sb *strs.Builder, pf *File, pd protorefl
case genid.EnumDescriptorProto_Value_field_number: case genid.EnumDescriptorProto_Value_field_number:
numValues++ numValues++
} }
case protowire.VarintType:
v, m := protowire.ConsumeVarint(b)
b = b[m:]
switch num {
case genid.EnumDescriptorProto_Visibility_field_number:
ed.L1.Visibility = int32(v)
}
default: default:
m := protowire.ConsumeFieldValue(num, typ, b) m := protowire.ConsumeFieldValue(num, typ, b)
b = b[m:] b = b[m:]
@@ -365,6 +372,13 @@ func (md *Message) unmarshalSeed(b []byte, sb *strs.Builder, pf *File, pd protor
md.unmarshalSeedOptions(v) md.unmarshalSeedOptions(v)
} }
prevField = num prevField = num
case protowire.VarintType:
v, m := protowire.ConsumeVarint(b)
b = b[m:]
switch num {
case genid.DescriptorProto_Visibility_field_number:
md.L1.Visibility = int32(v)
}
default: default:
m := protowire.ConsumeFieldValue(num, typ, b) m := protowire.ConsumeFieldValue(num, typ, b)
b = b[m:] b = b[m:]

View File

@@ -134,6 +134,7 @@ func (fd *File) unmarshalFull(b []byte) {
var enumIdx, messageIdx, extensionIdx, serviceIdx int var enumIdx, messageIdx, extensionIdx, serviceIdx int
var rawOptions []byte var rawOptions []byte
var optionImports []string
fd.L2 = new(FileL2) fd.L2 = new(FileL2)
for len(b) > 0 { for len(b) > 0 {
num, typ, n := protowire.ConsumeTag(b) num, typ, n := protowire.ConsumeTag(b)
@@ -157,6 +158,8 @@ func (fd *File) unmarshalFull(b []byte) {
imp = PlaceholderFile(path) imp = PlaceholderFile(path)
} }
fd.L2.Imports = append(fd.L2.Imports, protoreflect.FileImport{FileDescriptor: imp}) fd.L2.Imports = append(fd.L2.Imports, protoreflect.FileImport{FileDescriptor: imp})
case genid.FileDescriptorProto_OptionDependency_field_number:
optionImports = append(optionImports, sb.MakeString(v))
case genid.FileDescriptorProto_EnumType_field_number: case genid.FileDescriptorProto_EnumType_field_number:
fd.L1.Enums.List[enumIdx].unmarshalFull(v, sb) fd.L1.Enums.List[enumIdx].unmarshalFull(v, sb)
enumIdx++ enumIdx++
@@ -178,6 +181,23 @@ func (fd *File) unmarshalFull(b []byte) {
} }
} }
fd.L2.Options = fd.builder.optionsUnmarshaler(&descopts.File, rawOptions) fd.L2.Options = fd.builder.optionsUnmarshaler(&descopts.File, rawOptions)
if len(optionImports) > 0 {
var imps FileImports
var once sync.Once
fd.L2.OptionImports = func() protoreflect.FileImports {
once.Do(func() {
imps = make(FileImports, len(optionImports))
for i, path := range optionImports {
imp, _ := fd.builder.FileRegistry.FindFileByPath(path)
if imp == nil {
imp = PlaceholderFile(path)
}
imps[i] = protoreflect.FileImport{FileDescriptor: imp}
}
})
return &imps
}
}
} }
func (ed *Enum) unmarshalFull(b []byte, sb *strs.Builder) { func (ed *Enum) unmarshalFull(b []byte, sb *strs.Builder) {

View File

@@ -13,8 +13,10 @@ import (
"google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoreflect"
) )
var defaultsCache = make(map[Edition]EditionFeatures) var (
var defaultsKeys = []Edition{} defaultsCache = make(map[Edition]EditionFeatures)
defaultsKeys = []Edition{}
)
func init() { func init() {
unmarshalEditionDefaults(editiondefaults.Defaults) unmarshalEditionDefaults(editiondefaults.Defaults)
@@ -41,7 +43,7 @@ func unmarshalGoFeature(b []byte, parent EditionFeatures) EditionFeatures {
b = b[m:] b = b[m:]
parent.StripEnumPrefix = int(v) parent.StripEnumPrefix = int(v)
default: default:
panic(fmt.Sprintf("unkown field number %d while unmarshalling GoFeatures", num)) panic(fmt.Sprintf("unknown field number %d while unmarshalling GoFeatures", num))
} }
} }
return parent return parent
@@ -76,7 +78,7 @@ func unmarshalFeatureSet(b []byte, parent EditionFeatures) EditionFeatures {
// DefaultSymbolVisibility is enforced in protoc, runtimes should not // DefaultSymbolVisibility is enforced in protoc, runtimes should not
// inspect this value. // inspect this value.
default: default:
panic(fmt.Sprintf("unkown field number %d while unmarshalling FeatureSet", num)) panic(fmt.Sprintf("unknown field number %d while unmarshalling FeatureSet", num))
} }
case protowire.BytesType: case protowire.BytesType:
v, m := protowire.ConsumeBytes(b) v, m := protowire.ConsumeBytes(b)
@@ -150,7 +152,7 @@ func unmarshalEditionDefaults(b []byte) {
_, m := protowire.ConsumeVarint(b) _, m := protowire.ConsumeVarint(b)
b = b[m:] b = b[m:]
default: default:
panic(fmt.Sprintf("unkown field number %d while unmarshalling EditionDefault", num)) panic(fmt.Sprintf("unknown field number %d while unmarshalling EditionDefault", num))
} }
} }
} }

View File

@@ -27,6 +27,7 @@ const (
Api_SourceContext_field_name protoreflect.Name = "source_context" Api_SourceContext_field_name protoreflect.Name = "source_context"
Api_Mixins_field_name protoreflect.Name = "mixins" Api_Mixins_field_name protoreflect.Name = "mixins"
Api_Syntax_field_name protoreflect.Name = "syntax" Api_Syntax_field_name protoreflect.Name = "syntax"
Api_Edition_field_name protoreflect.Name = "edition"
Api_Name_field_fullname protoreflect.FullName = "google.protobuf.Api.name" Api_Name_field_fullname protoreflect.FullName = "google.protobuf.Api.name"
Api_Methods_field_fullname protoreflect.FullName = "google.protobuf.Api.methods" Api_Methods_field_fullname protoreflect.FullName = "google.protobuf.Api.methods"
@@ -35,6 +36,7 @@ const (
Api_SourceContext_field_fullname protoreflect.FullName = "google.protobuf.Api.source_context" Api_SourceContext_field_fullname protoreflect.FullName = "google.protobuf.Api.source_context"
Api_Mixins_field_fullname protoreflect.FullName = "google.protobuf.Api.mixins" Api_Mixins_field_fullname protoreflect.FullName = "google.protobuf.Api.mixins"
Api_Syntax_field_fullname protoreflect.FullName = "google.protobuf.Api.syntax" Api_Syntax_field_fullname protoreflect.FullName = "google.protobuf.Api.syntax"
Api_Edition_field_fullname protoreflect.FullName = "google.protobuf.Api.edition"
) )
// Field numbers for google.protobuf.Api. // Field numbers for google.protobuf.Api.
@@ -46,6 +48,7 @@ const (
Api_SourceContext_field_number protoreflect.FieldNumber = 5 Api_SourceContext_field_number protoreflect.FieldNumber = 5
Api_Mixins_field_number protoreflect.FieldNumber = 6 Api_Mixins_field_number protoreflect.FieldNumber = 6
Api_Syntax_field_number protoreflect.FieldNumber = 7 Api_Syntax_field_number protoreflect.FieldNumber = 7
Api_Edition_field_number protoreflect.FieldNumber = 8
) )
// Names for google.protobuf.Method. // Names for google.protobuf.Method.
@@ -63,6 +66,7 @@ const (
Method_ResponseStreaming_field_name protoreflect.Name = "response_streaming" Method_ResponseStreaming_field_name protoreflect.Name = "response_streaming"
Method_Options_field_name protoreflect.Name = "options" Method_Options_field_name protoreflect.Name = "options"
Method_Syntax_field_name protoreflect.Name = "syntax" Method_Syntax_field_name protoreflect.Name = "syntax"
Method_Edition_field_name protoreflect.Name = "edition"
Method_Name_field_fullname protoreflect.FullName = "google.protobuf.Method.name" Method_Name_field_fullname protoreflect.FullName = "google.protobuf.Method.name"
Method_RequestTypeUrl_field_fullname protoreflect.FullName = "google.protobuf.Method.request_type_url" Method_RequestTypeUrl_field_fullname protoreflect.FullName = "google.protobuf.Method.request_type_url"
@@ -71,6 +75,7 @@ const (
Method_ResponseStreaming_field_fullname protoreflect.FullName = "google.protobuf.Method.response_streaming" Method_ResponseStreaming_field_fullname protoreflect.FullName = "google.protobuf.Method.response_streaming"
Method_Options_field_fullname protoreflect.FullName = "google.protobuf.Method.options" Method_Options_field_fullname protoreflect.FullName = "google.protobuf.Method.options"
Method_Syntax_field_fullname protoreflect.FullName = "google.protobuf.Method.syntax" Method_Syntax_field_fullname protoreflect.FullName = "google.protobuf.Method.syntax"
Method_Edition_field_fullname protoreflect.FullName = "google.protobuf.Method.edition"
) )
// Field numbers for google.protobuf.Method. // Field numbers for google.protobuf.Method.
@@ -82,6 +87,7 @@ const (
Method_ResponseStreaming_field_number protoreflect.FieldNumber = 5 Method_ResponseStreaming_field_number protoreflect.FieldNumber = 5
Method_Options_field_number protoreflect.FieldNumber = 6 Method_Options_field_number protoreflect.FieldNumber = 6
Method_Syntax_field_number protoreflect.FieldNumber = 7 Method_Syntax_field_number protoreflect.FieldNumber = 7
Method_Edition_field_number protoreflect.FieldNumber = 8
) )
// Names for google.protobuf.Mixin. // Names for google.protobuf.Mixin.

View File

@@ -52,7 +52,7 @@ import (
const ( const (
Major = 1 Major = 1
Minor = 36 Minor = 36
Patch = 7 Patch = 10
PreRelease = "" PreRelease = ""
) )

1
vendor/modernc.org/libc/AUTHORS generated vendored
View File

@@ -17,6 +17,7 @@ Jason DeBettencourt <jasond17@gmail.com>
Jasper Siepkes <jasper@siepkes.nl> Jasper Siepkes <jasper@siepkes.nl>
Koichi Shiraishi <zchee.io@gmail.com> Koichi Shiraishi <zchee.io@gmail.com>
Marius Orcsik <marius@federated.id> Marius Orcsik <marius@federated.id>
Olivier Mengué <dolmen@cpan.org>
Patricio Whittingslow <graded.sp@gmail.com> Patricio Whittingslow <graded.sp@gmail.com>
Scot C Bontrager <scot@indievisible.org> Scot C Bontrager <scot@indievisible.org>
Steffen Butzer <steffen(dot)butzer@outlook.com> Steffen Butzer <steffen(dot)butzer@outlook.com>

31
vendor/modernc.org/libc/CONTRIBUTING.md generated vendored Normal file
View File

@@ -0,0 +1,31 @@
# Contributing to this repository
Thank you for your interest in contributing! To help keep the project stable across its many targets, please follow these guidelines when submitting a pull request or merge request.
### Verification
Before submitting your changes, please ensure that they do not break the build for different architectures or build tags.
Run the following script in your local environment:
```bash
$ ./build_all_targets.sh
```
Please verify that all targets you can test pass before opening your request.
### Authors and Contributors
If you would like yourself and/or your company to be officially recognized in the project:
* Optionally, please include a change to the AUTHORS and/or CONTRIBUTORS files within your merge request.
### The Process
* Fork the repository (or host a public branch if you do not have a gitlab.com account).
* Implement your changes, keeping them as focused as possible.
* Submit your request with a clear description of the problem solved, the dependency improved, etc.
----
We appreciate your help in making the Go ecosystem more robust!

View File

@@ -18,6 +18,7 @@ Jasper Siepkes <jasper@siepkes.nl>
Koichi Shiraishi <zchee.io@gmail.com> Koichi Shiraishi <zchee.io@gmail.com>
Leonardo Taccari <leot@NetBSD.org> Leonardo Taccari <leot@NetBSD.org>
Marius Orcsik <marius@federated.id> Marius Orcsik <marius@federated.id>
Olivier Mengué <dolmen@cpan.org>
Patricio Whittingslow <graded.sp@gmail.com> Patricio Whittingslow <graded.sp@gmail.com>
Roman Khafizianov <roman@any.org> Roman Khafizianov <roman@any.org>
Scot C Bontrager <scot@indievisible.org> Scot C Bontrager <scot@indievisible.org>

Some files were not shown because too many files have changed in this diff Show More