Use mcdsl/terminal for all password prompts

Replace direct golang.org/x/term calls with mcdsl/terminal.ReadPassword
across mciasctl (6 sites), mciasgrpcctl (1 site), and mciasdb (1 site).
Aligns with the new CLI security standard in engineering-standards.md.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-28 11:40:11 -07:00
parent e4220b840e
commit 5b5e1a7ed6
142 changed files with 10241 additions and 7788 deletions

View File

@@ -33,17 +33,21 @@ guidelines, there may be valid reasons to do so, but it should be rare.
## Guidelines for Pull Requests
How to get your contributions merged smoothly and quickly:
Please read the following carefully to ensure your contributions can be merged
smoothly and quickly.
### PR Contents
- Create **small PRs** that are narrowly focused on **addressing a single
concern**. We often receive PRs that attempt to fix several things at the same
time, and if one part of the PR has a problem, that will hold up the entire
PR.
- For **speculative changes**, consider opening an issue and discussing it
first. If you are suggesting a behavioral or API change, consider starting
with a [gRFC proposal](https://github.com/grpc/proposal). Many new features
that are not bug fixes will require cross-language agreement.
- If your change does not address an **open issue** with an **agreed
resolution**, consider opening an issue and discussing it first. If you are
suggesting a behavioral or API change, consider starting with a [gRFC
proposal](https://github.com/grpc/proposal). Many new features that are not
bug fixes will require cross-language agreement.
- If you want to fix **formatting or style**, consider whether your changes are
an obvious improvement or might be considered a personal preference. If a
@@ -56,16 +60,6 @@ How to get your contributions merged smoothly and quickly:
often written as "iff". Please do not make spelling correction changes unless
you are certain they are misspellings.
- Provide a good **PR description** as a record of **what** change is being made
and **why** it was made. Link to a GitHub issue if it exists.
- Maintain a **clean commit history** and use **meaningful commit messages**.
PRs with messy commit histories are difficult to review and won't be merged.
Before sending your PR, ensure your changes are based on top of the latest
`upstream/master` commits, and avoid rebasing in the middle of a code review.
You should **never use `git push -f`** unless absolutely necessary during a
review, as it can interfere with GitHub's tracking of comments.
- **All tests need to be passing** before your change can be merged. We
recommend you run tests locally before creating your PR to catch breakages
early on:
@@ -81,15 +75,80 @@ How to get your contributions merged smoothly and quickly:
GitHub, which will trigger a GitHub Actions run that you can use to verify
everything is passing.
- If you are adding a new file, make sure it has the **copyright message**
- Note that there are two GitHub actions checks that need not be green:
1. We test the freshness of the generated proto code we maintain via the
`vet-proto` check. If the source proto files are updated, but our repo is
not updated, an optional checker will fail. This will be fixed by our team
in a separate PR and will not prevent the merge of your PR.
2. We run a checker that will fail if there is any change in dependencies of
an exported package via the `dependencies` check. If new dependencies are
added that are not appropriate, we may not accept your PR (see below).
- If you are adding a **new file**, make sure it has the **copyright message**
template at the top as a comment. You can copy the message from an existing
file and update the year.
- The grpc package should only depend on standard Go packages and a small number
of exceptions. **If your contribution introduces new dependencies**, you will
need a discussion with gRPC-Go maintainers. A GitHub action check will run on
every PR, and will flag any transitive dependency changes from any public
package.
need a discussion with gRPC-Go maintainers.
### PR Descriptions
- **PR titles** should start with the name of the component being addressed, or
the type of change. Examples: transport, client, server, round_robin, xds,
cleanup, deps.
- Read and follow the **guidelines for PR titles and descriptions** here:
https://google.github.io/eng-practices/review/developer/cl-descriptions.html
*particularly* the sections "First Line" and "Body is Informative".
Note: your PR description will be used as the git commit message in a
squash-and-merge if your PR is approved. We may make changes to this as
necessary.
- **Does this PR relate to an open issue?** On the first line, please use the
tag `Fixes #<issue>` to ensure the issue is closed when the PR is merged. Or
use `Updates #<issue>` if the PR is related to an open issue, but does not fix
it. Consider filing an issue if one does not already exist.
- PR descriptions *must* conclude with **release notes** as follows:
```
RELEASE NOTES:
* <component>: <summary>
```
This need not match the PR title.
The summary must:
* be something that gRPC users will understand.
* clearly explain the feature being added, the issue being fixed, or the
behavior being changed, etc. If fixing a bug, be clear about how the bug
can be triggered by an end-user.
* begin with a capital letter and use complete sentences.
* be as short as possible to describe the change being made.
If a PR is *not* end-user visible -- e.g. a cleanup, testing change, or
GitHub-related, use `RELEASE NOTES: n/a`.
### PR Process
- Please **self-review** your code changes before sending your PR. This will
prevent simple, obvious errors from causing delays.
- Maintain a **clean commit history** and use **meaningful commit messages**.
PRs with messy commit histories are difficult to review and won't be merged.
Before sending your PR, ensure your changes are based on top of the latest
`upstream/master` commits, and avoid rebasing in the middle of a code review.
You should **never use `git push -f`** unless absolutely necessary during a
review, as it can interfere with GitHub's tracking of comments.
- Unless your PR is trivial, you should **expect reviewer comments** that you
will need to address before merging. We'll label the PR as `Status: Requires
@@ -98,5 +157,3 @@ How to get your contributions merged smoothly and quickly:
`stale`, and we will automatically close it after 7 days if we don't hear back
from you. Please feel free to ping issues or bugs if you do not get a response
within a week.
- Exceptions to the rules can be made if there's a compelling reason to do so.

View File

@@ -9,21 +9,19 @@ for general contribution guidelines.
## Maintainers (in alphabetical order)
- [aranjans](https://github.com/aranjans), Google LLC
- [arjan-bal](https://github.com/arjan-bal), Google LLC
- [arvindbr8](https://github.com/arvindbr8), Google LLC
- [atollena](https://github.com/atollena), Datadog, Inc.
- [dfawley](https://github.com/dfawley), Google LLC
- [easwars](https://github.com/easwars), Google LLC
- [erm-g](https://github.com/erm-g), Google LLC
- [gtcooke94](https://github.com/gtcooke94), Google LLC
- [purnesh42h](https://github.com/purnesh42h), Google LLC
- [zasweq](https://github.com/zasweq), Google LLC
## Emeritus Maintainers (in alphabetical order)
- [adelez](https://github.com/adelez)
- [aranjans](https://github.com/aranjans)
- [canguler](https://github.com/canguler)
- [cesarghali](https://github.com/cesarghali)
- [erm-g](https://github.com/erm-g)
- [iamqizhao](https://github.com/iamqizhao)
- [jeanbza](https://github.com/jeanbza)
- [jtattermusch](https://github.com/jtattermusch)
@@ -32,5 +30,7 @@ for general contribution guidelines.
- [matt-kwong](https://github.com/matt-kwong)
- [menghanl](https://github.com/menghanl)
- [nicolasnoble](https://github.com/nicolasnoble)
- [purnesh42h](https://github.com/purnesh42h)
- [srini100](https://github.com/srini100)
- [yongni](https://github.com/yongni)
- [zasweq](https://github.com/zasweq)

View File

@@ -75,8 +75,6 @@ func unregisterForTesting(name string) {
func init() {
internal.BalancerUnregister = unregisterForTesting
internal.ConnectedAddress = connectedAddress
internal.SetConnectedAddress = setConnectedAddress
}
// Get returns the resolver builder registered with the given name.

View File

@@ -37,6 +37,8 @@ import (
"google.golang.org/grpc/resolver"
)
var randIntN = rand.IntN
// ChildState is the balancer state of a child along with the endpoint which
// identifies the child balancer.
type ChildState struct {
@@ -112,6 +114,21 @@ type endpointSharding struct {
mu sync.Mutex
}
// rotateEndpoints returns a slice of all the input endpoints rotated a random
// amount.
func rotateEndpoints(es []resolver.Endpoint) []resolver.Endpoint {
les := len(es)
if les == 0 {
return es
}
r := randIntN(les)
// Make a copy to avoid mutating data beyond the end of es.
ret := make([]resolver.Endpoint, les)
copy(ret, es[r:])
copy(ret[les-r:], es[:r])
return ret
}
// UpdateClientConnState creates a child for new endpoints and deletes children
// for endpoints that are no longer present. It also updates all the children,
// and sends a single synchronous update of the childrens' aggregated state at
@@ -133,7 +150,7 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState
newChildren := resolver.NewEndpointMap[*balancerWrapper]()
// Update/Create new children.
for _, endpoint := range state.ResolverState.Endpoints {
for _, endpoint := range rotateEndpoints(state.ResolverState.Endpoints) {
if _, ok := newChildren.Get(endpoint); ok {
// Endpoint child was already created, continue to avoid duplicate
// update.
@@ -279,7 +296,7 @@ func (es *endpointSharding) updateState() {
p := &pickerWithChildStates{
pickers: pickers,
childStates: childStates,
next: uint32(rand.IntN(len(pickers))),
next: uint32(randIntN(len(pickers))),
}
es.cc.UpdateState(balancer.State{
ConnectivityState: aggState,

View File

@@ -26,6 +26,8 @@ import (
var (
// RandShuffle pseudo-randomizes the order of addresses.
RandShuffle = rand.Shuffle
// RandFloat64 returns, as a float64, a pseudo-random number in [0.0,1.0).
RandFloat64 = rand.Float64
// TimeAfterFunc allows mocking the timer for testing connection delay
// related functionality.
TimeAfterFunc = func(d time.Duration, f func()) func() {

File diff suppressed because it is too large Load Diff

View File

@@ -1,906 +0,0 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package pickfirstleaf contains the pick_first load balancing policy which
// will be the universal leaf policy after dualstack changes are implemented.
//
// # Experimental
//
// Notice: This package is EXPERIMENTAL and may be changed or removed in a
// later release.
package pickfirstleaf
import (
"encoding/json"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/pickfirst/internal"
"google.golang.org/grpc/connectivity"
expstats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/envconfig"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
func init() {
if envconfig.NewPickFirstEnabled {
// Register as the default pick_first balancer.
Name = "pick_first"
}
balancer.Register(pickfirstBuilder{})
}
// enableHealthListenerKeyType is a unique key type used in resolver
// attributes to indicate whether the health listener usage is enabled.
type enableHealthListenerKeyType struct{}
var (
logger = grpclog.Component("pick-first-leaf-lb")
// Name is the name of the pick_first_leaf balancer.
// It is changed to "pick_first" in init() if this balancer is to be
// registered as the default pickfirst.
Name = "pick_first_leaf"
disconnectionsMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.lb.pick_first.disconnections",
Description: "EXPERIMENTAL. Number of times the selected subchannel becomes disconnected.",
Unit: "disconnection",
Labels: []string{"grpc.target"},
Default: false,
})
connectionAttemptsSucceededMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.lb.pick_first.connection_attempts_succeeded",
Description: "EXPERIMENTAL. Number of successful connection attempts.",
Unit: "attempt",
Labels: []string{"grpc.target"},
Default: false,
})
connectionAttemptsFailedMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.lb.pick_first.connection_attempts_failed",
Description: "EXPERIMENTAL. Number of failed connection attempts.",
Unit: "attempt",
Labels: []string{"grpc.target"},
Default: false,
})
)
const (
// TODO: change to pick-first when this becomes the default pick_first policy.
logPrefix = "[pick-first-leaf-lb %p] "
// connectionDelayInterval is the time to wait for during the happy eyeballs
// pass before starting the next connection attempt.
connectionDelayInterval = 250 * time.Millisecond
)
type ipAddrFamily int
const (
// ipAddrFamilyUnknown represents strings that can't be parsed as an IP
// address.
ipAddrFamilyUnknown ipAddrFamily = iota
ipAddrFamilyV4
ipAddrFamilyV6
)
type pickfirstBuilder struct{}
func (pickfirstBuilder) Build(cc balancer.ClientConn, bo balancer.BuildOptions) balancer.Balancer {
b := &pickfirstBalancer{
cc: cc,
target: bo.Target.String(),
metricsRecorder: cc.MetricsRecorder(),
subConns: resolver.NewAddressMapV2[*scData](),
state: connectivity.Connecting,
cancelConnectionTimer: func() {},
}
b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(logPrefix, b))
return b
}
func (b pickfirstBuilder) Name() string {
return Name
}
func (pickfirstBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
var cfg pfConfig
if err := json.Unmarshal(js, &cfg); err != nil {
return nil, fmt.Errorf("pickfirst: unable to unmarshal LB policy config: %s, error: %v", string(js), err)
}
return cfg, nil
}
// EnableHealthListener updates the state to configure pickfirst for using a
// generic health listener.
func EnableHealthListener(state resolver.State) resolver.State {
state.Attributes = state.Attributes.WithValue(enableHealthListenerKeyType{}, true)
return state
}
type pfConfig struct {
serviceconfig.LoadBalancingConfig `json:"-"`
// If set to true, instructs the LB policy to shuffle the order of the list
// of endpoints received from the name resolver before attempting to
// connect to them.
ShuffleAddressList bool `json:"shuffleAddressList"`
}
// scData keeps track of the current state of the subConn.
// It is not safe for concurrent access.
type scData struct {
// The following fields are initialized at build time and read-only after
// that.
subConn balancer.SubConn
addr resolver.Address
rawConnectivityState connectivity.State
// The effective connectivity state based on raw connectivity, health state
// and after following sticky TransientFailure behaviour defined in A62.
effectiveState connectivity.State
lastErr error
connectionFailedInFirstPass bool
}
func (b *pickfirstBalancer) newSCData(addr resolver.Address) (*scData, error) {
sd := &scData{
rawConnectivityState: connectivity.Idle,
effectiveState: connectivity.Idle,
addr: addr,
}
sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{
StateListener: func(state balancer.SubConnState) {
b.updateSubConnState(sd, state)
},
})
if err != nil {
return nil, err
}
sd.subConn = sc
return sd, nil
}
type pickfirstBalancer struct {
// The following fields are initialized at build time and read-only after
// that and therefore do not need to be guarded by a mutex.
logger *internalgrpclog.PrefixLogger
cc balancer.ClientConn
target string
metricsRecorder expstats.MetricsRecorder // guaranteed to be non nil
// The mutex is used to ensure synchronization of updates triggered
// from the idle picker and the already serialized resolver,
// SubConn state updates.
mu sync.Mutex
// State reported to the channel based on SubConn states and resolver
// updates.
state connectivity.State
// scData for active subonns mapped by address.
subConns *resolver.AddressMapV2[*scData]
addressList addressList
firstPass bool
numTF int
cancelConnectionTimer func()
healthCheckingEnabled bool
}
// ResolverError is called by the ClientConn when the name resolver produces
// an error or when pickfirst determined the resolver update to be invalid.
func (b *pickfirstBalancer) ResolverError(err error) {
b.mu.Lock()
defer b.mu.Unlock()
b.resolverErrorLocked(err)
}
func (b *pickfirstBalancer) resolverErrorLocked(err error) {
if b.logger.V(2) {
b.logger.Infof("Received error from the name resolver: %v", err)
}
// The picker will not change since the balancer does not currently
// report an error. If the balancer hasn't received a single good resolver
// update yet, transition to TRANSIENT_FAILURE.
if b.state != connectivity.TransientFailure && b.addressList.size() > 0 {
if b.logger.V(2) {
b.logger.Infof("Ignoring resolver error because balancer is using a previous good update.")
}
return
}
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: fmt.Errorf("name resolver error: %v", err)},
})
}
func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error {
b.mu.Lock()
defer b.mu.Unlock()
b.cancelConnectionTimer()
if len(state.ResolverState.Addresses) == 0 && len(state.ResolverState.Endpoints) == 0 {
// Cleanup state pertaining to the previous resolver state.
// Treat an empty address list like an error by calling b.ResolverError.
b.closeSubConnsLocked()
b.addressList.updateAddrs(nil)
b.resolverErrorLocked(errors.New("produced zero addresses"))
return balancer.ErrBadResolverState
}
b.healthCheckingEnabled = state.ResolverState.Attributes.Value(enableHealthListenerKeyType{}) != nil
cfg, ok := state.BalancerConfig.(pfConfig)
if state.BalancerConfig != nil && !ok {
return fmt.Errorf("pickfirst: received illegal BalancerConfig (type %T): %v: %w", state.BalancerConfig, state.BalancerConfig, balancer.ErrBadResolverState)
}
if b.logger.V(2) {
b.logger.Infof("Received new config %s, resolver state %s", pretty.ToJSON(cfg), pretty.ToJSON(state.ResolverState))
}
var newAddrs []resolver.Address
if endpoints := state.ResolverState.Endpoints; len(endpoints) != 0 {
// Perform the optional shuffling described in gRFC A62. The shuffling
// will change the order of endpoints but not touch the order of the
// addresses within each endpoint. - A61
if cfg.ShuffleAddressList {
endpoints = append([]resolver.Endpoint{}, endpoints...)
internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] })
}
// "Flatten the list by concatenating the ordered list of addresses for
// each of the endpoints, in order." - A61
for _, endpoint := range endpoints {
newAddrs = append(newAddrs, endpoint.Addresses...)
}
} else {
// Endpoints not set, process addresses until we migrate resolver
// emissions fully to Endpoints. The top channel does wrap emitted
// addresses with endpoints, however some balancers such as weighted
// target do not forward the corresponding correct endpoints down/split
// endpoints properly. Once all balancers correctly forward endpoints
// down, can delete this else conditional.
newAddrs = state.ResolverState.Addresses
if cfg.ShuffleAddressList {
newAddrs = append([]resolver.Address{}, newAddrs...)
internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] })
}
}
// If an address appears in multiple endpoints or in the same endpoint
// multiple times, we keep it only once. We will create only one SubConn
// for the address because an AddressMap is used to store SubConns.
// Not de-duplicating would result in attempting to connect to the same
// SubConn multiple times in the same pass. We don't want this.
newAddrs = deDupAddresses(newAddrs)
newAddrs = interleaveAddresses(newAddrs)
prevAddr := b.addressList.currentAddress()
prevSCData, found := b.subConns.Get(prevAddr)
prevAddrsCount := b.addressList.size()
isPrevRawConnectivityStateReady := found && prevSCData.rawConnectivityState == connectivity.Ready
b.addressList.updateAddrs(newAddrs)
// If the previous ready SubConn exists in new address list,
// keep this connection and don't create new SubConns.
if isPrevRawConnectivityStateReady && b.addressList.seekTo(prevAddr) {
return nil
}
b.reconcileSubConnsLocked(newAddrs)
// If it's the first resolver update or the balancer was already READY
// (but the new address list does not contain the ready SubConn) or
// CONNECTING, enter CONNECTING.
// We may be in TRANSIENT_FAILURE due to a previous empty address list,
// we should still enter CONNECTING because the sticky TF behaviour
// mentioned in A62 applies only when the TRANSIENT_FAILURE is reported
// due to connectivity failures.
if isPrevRawConnectivityStateReady || b.state == connectivity.Connecting || prevAddrsCount == 0 {
// Start connection attempt at first address.
b.forceUpdateConcludedStateLocked(balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: &picker{err: balancer.ErrNoSubConnAvailable},
})
b.startFirstPassLocked()
} else if b.state == connectivity.TransientFailure {
// If we're in TRANSIENT_FAILURE, we stay in TRANSIENT_FAILURE until
// we're READY. See A62.
b.startFirstPassLocked()
}
return nil
}
// UpdateSubConnState is unused as a StateListener is always registered when
// creating SubConns.
func (b *pickfirstBalancer) UpdateSubConnState(subConn balancer.SubConn, state balancer.SubConnState) {
b.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", subConn, state)
}
func (b *pickfirstBalancer) Close() {
b.mu.Lock()
defer b.mu.Unlock()
b.closeSubConnsLocked()
b.cancelConnectionTimer()
b.state = connectivity.Shutdown
}
// ExitIdle moves the balancer out of idle state. It can be called concurrently
// by the idlePicker and clientConn so access to variables should be
// synchronized.
func (b *pickfirstBalancer) ExitIdle() {
b.mu.Lock()
defer b.mu.Unlock()
if b.state == connectivity.Idle {
b.startFirstPassLocked()
}
}
func (b *pickfirstBalancer) startFirstPassLocked() {
b.firstPass = true
b.numTF = 0
// Reset the connection attempt record for existing SubConns.
for _, sd := range b.subConns.Values() {
sd.connectionFailedInFirstPass = false
}
b.requestConnectionLocked()
}
func (b *pickfirstBalancer) closeSubConnsLocked() {
for _, sd := range b.subConns.Values() {
sd.subConn.Shutdown()
}
b.subConns = resolver.NewAddressMapV2[*scData]()
}
// deDupAddresses ensures that each address appears only once in the slice.
func deDupAddresses(addrs []resolver.Address) []resolver.Address {
seenAddrs := resolver.NewAddressMapV2[*scData]()
retAddrs := []resolver.Address{}
for _, addr := range addrs {
if _, ok := seenAddrs.Get(addr); ok {
continue
}
retAddrs = append(retAddrs, addr)
}
return retAddrs
}
// interleaveAddresses interleaves addresses of both families (IPv4 and IPv6)
// as per RFC-8305 section 4.
// Whichever address family is first in the list is followed by an address of
// the other address family; that is, if the first address in the list is IPv6,
// then the first IPv4 address should be moved up in the list to be second in
// the list. It doesn't support configuring "First Address Family Count", i.e.
// there will always be a single member of the first address family at the
// beginning of the interleaved list.
// Addresses that are neither IPv4 nor IPv6 are treated as part of a third
// "unknown" family for interleaving.
// See: https://datatracker.ietf.org/doc/html/rfc8305#autoid-6
func interleaveAddresses(addrs []resolver.Address) []resolver.Address {
familyAddrsMap := map[ipAddrFamily][]resolver.Address{}
interleavingOrder := []ipAddrFamily{}
for _, addr := range addrs {
family := addressFamily(addr.Addr)
if _, found := familyAddrsMap[family]; !found {
interleavingOrder = append(interleavingOrder, family)
}
familyAddrsMap[family] = append(familyAddrsMap[family], addr)
}
interleavedAddrs := make([]resolver.Address, 0, len(addrs))
for curFamilyIdx := 0; len(interleavedAddrs) < len(addrs); curFamilyIdx = (curFamilyIdx + 1) % len(interleavingOrder) {
// Some IP types may have fewer addresses than others, so we look for
// the next type that has a remaining member to add to the interleaved
// list.
family := interleavingOrder[curFamilyIdx]
remainingMembers := familyAddrsMap[family]
if len(remainingMembers) > 0 {
interleavedAddrs = append(interleavedAddrs, remainingMembers[0])
familyAddrsMap[family] = remainingMembers[1:]
}
}
return interleavedAddrs
}
// addressFamily returns the ipAddrFamily after parsing the address string.
// If the address isn't of the format "ip-address:port", it returns
// ipAddrFamilyUnknown. The address may be valid even if it's not an IP when
// using a resolver like passthrough where the address may be a hostname in
// some format that the dialer can resolve.
func addressFamily(address string) ipAddrFamily {
// Parse the IP after removing the port.
host, _, err := net.SplitHostPort(address)
if err != nil {
return ipAddrFamilyUnknown
}
ip, err := netip.ParseAddr(host)
if err != nil {
return ipAddrFamilyUnknown
}
switch {
case ip.Is4() || ip.Is4In6():
return ipAddrFamilyV4
case ip.Is6():
return ipAddrFamilyV6
default:
return ipAddrFamilyUnknown
}
}
// reconcileSubConnsLocked updates the active subchannels based on a new address
// list from the resolver. It does this by:
// - closing subchannels: any existing subchannels associated with addresses
// that are no longer in the updated list are shut down.
// - removing subchannels: entries for these closed subchannels are removed
// from the subchannel map.
//
// This ensures that the subchannel map accurately reflects the current set of
// addresses received from the name resolver.
func (b *pickfirstBalancer) reconcileSubConnsLocked(newAddrs []resolver.Address) {
newAddrsMap := resolver.NewAddressMapV2[bool]()
for _, addr := range newAddrs {
newAddrsMap.Set(addr, true)
}
for _, oldAddr := range b.subConns.Keys() {
if _, ok := newAddrsMap.Get(oldAddr); ok {
continue
}
val, _ := b.subConns.Get(oldAddr)
val.subConn.Shutdown()
b.subConns.Delete(oldAddr)
}
}
// shutdownRemainingLocked shuts down remaining subConns. Called when a subConn
// becomes ready, which means that all other subConn must be shutdown.
func (b *pickfirstBalancer) shutdownRemainingLocked(selected *scData) {
b.cancelConnectionTimer()
for _, sd := range b.subConns.Values() {
if sd.subConn != selected.subConn {
sd.subConn.Shutdown()
}
}
b.subConns = resolver.NewAddressMapV2[*scData]()
b.subConns.Set(selected.addr, selected)
}
// requestConnectionLocked starts connecting on the subchannel corresponding to
// the current address. If no subchannel exists, one is created. If the current
// subchannel is in TransientFailure, a connection to the next address is
// attempted until a subchannel is found.
func (b *pickfirstBalancer) requestConnectionLocked() {
if !b.addressList.isValid() {
return
}
var lastErr error
for valid := true; valid; valid = b.addressList.increment() {
curAddr := b.addressList.currentAddress()
sd, ok := b.subConns.Get(curAddr)
if !ok {
var err error
// We want to assign the new scData to sd from the outer scope,
// hence we can't use := below.
sd, err = b.newSCData(curAddr)
if err != nil {
// This should never happen, unless the clientConn is being shut
// down.
if b.logger.V(2) {
b.logger.Infof("Failed to create a subConn for address %v: %v", curAddr.String(), err)
}
// Do nothing, the LB policy will be closed soon.
return
}
b.subConns.Set(curAddr, sd)
}
switch sd.rawConnectivityState {
case connectivity.Idle:
sd.subConn.Connect()
b.scheduleNextConnectionLocked()
return
case connectivity.TransientFailure:
// The SubConn is being re-used and failed during a previous pass
// over the addressList. It has not completed backoff yet.
// Mark it as having failed and try the next address.
sd.connectionFailedInFirstPass = true
lastErr = sd.lastErr
continue
case connectivity.Connecting:
// Wait for the connection attempt to complete or the timer to fire
// before attempting the next address.
b.scheduleNextConnectionLocked()
return
default:
b.logger.Errorf("SubConn with unexpected state %v present in SubConns map.", sd.rawConnectivityState)
return
}
}
// All the remaining addresses in the list are in TRANSIENT_FAILURE, end the
// first pass if possible.
b.endFirstPassIfPossibleLocked(lastErr)
}
func (b *pickfirstBalancer) scheduleNextConnectionLocked() {
b.cancelConnectionTimer()
if !b.addressList.hasNext() {
return
}
curAddr := b.addressList.currentAddress()
cancelled := false // Access to this is protected by the balancer's mutex.
closeFn := internal.TimeAfterFunc(connectionDelayInterval, func() {
b.mu.Lock()
defer b.mu.Unlock()
// If the scheduled task is cancelled while acquiring the mutex, return.
if cancelled {
return
}
if b.logger.V(2) {
b.logger.Infof("Happy Eyeballs timer expired while waiting for connection to %q.", curAddr.Addr)
}
if b.addressList.increment() {
b.requestConnectionLocked()
}
})
// Access to the cancellation callback held by the balancer is guarded by
// the balancer's mutex, so it's safe to set the boolean from the callback.
b.cancelConnectionTimer = sync.OnceFunc(func() {
cancelled = true
closeFn()
})
}
func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.SubConnState) {
b.mu.Lock()
defer b.mu.Unlock()
oldState := sd.rawConnectivityState
sd.rawConnectivityState = newState.ConnectivityState
// Previously relevant SubConns can still callback with state updates.
// To prevent pickers from returning these obsolete SubConns, this logic
// is included to check if the current list of active SubConns includes this
// SubConn.
if !b.isActiveSCData(sd) {
return
}
if newState.ConnectivityState == connectivity.Shutdown {
sd.effectiveState = connectivity.Shutdown
return
}
// Record a connection attempt when exiting CONNECTING.
if newState.ConnectivityState == connectivity.TransientFailure {
sd.connectionFailedInFirstPass = true
connectionAttemptsFailedMetric.Record(b.metricsRecorder, 1, b.target)
}
if newState.ConnectivityState == connectivity.Ready {
connectionAttemptsSucceededMetric.Record(b.metricsRecorder, 1, b.target)
b.shutdownRemainingLocked(sd)
if !b.addressList.seekTo(sd.addr) {
// This should not fail as we should have only one SubConn after
// entering READY. The SubConn should be present in the addressList.
b.logger.Errorf("Address %q not found address list in %v", sd.addr, b.addressList.addresses)
return
}
if !b.healthCheckingEnabled {
if b.logger.V(2) {
b.logger.Infof("SubConn %p reported connectivity state READY and the health listener is disabled. Transitioning SubConn to READY.", sd.subConn)
}
sd.effectiveState = connectivity.Ready
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Ready,
Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}},
})
return
}
if b.logger.V(2) {
b.logger.Infof("SubConn %p reported connectivity state READY. Registering health listener.", sd.subConn)
}
// Send a CONNECTING update to take the SubConn out of sticky-TF if
// required.
sd.effectiveState = connectivity.Connecting
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: &picker{err: balancer.ErrNoSubConnAvailable},
})
sd.subConn.RegisterHealthListener(func(scs balancer.SubConnState) {
b.updateSubConnHealthState(sd, scs)
})
return
}
// If the LB policy is READY, and it receives a subchannel state change,
// it means that the READY subchannel has failed.
// A SubConn can also transition from CONNECTING directly to IDLE when
// a transport is successfully created, but the connection fails
// before the SubConn can send the notification for READY. We treat
// this as a successful connection and transition to IDLE.
// TODO: https://github.com/grpc/grpc-go/issues/7862 - Remove the second
// part of the if condition below once the issue is fixed.
if oldState == connectivity.Ready || (oldState == connectivity.Connecting && newState.ConnectivityState == connectivity.Idle) {
// Once a transport fails, the balancer enters IDLE and starts from
// the first address when the picker is used.
b.shutdownRemainingLocked(sd)
sd.effectiveState = newState.ConnectivityState
// READY SubConn interspliced in between CONNECTING and IDLE, need to
// account for that.
if oldState == connectivity.Connecting {
// A known issue (https://github.com/grpc/grpc-go/issues/7862)
// causes a race that prevents the READY state change notification.
// This works around it.
connectionAttemptsSucceededMetric.Record(b.metricsRecorder, 1, b.target)
}
disconnectionsMetric.Record(b.metricsRecorder, 1, b.target)
b.addressList.reset()
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Idle,
Picker: &idlePicker{exitIdle: sync.OnceFunc(b.ExitIdle)},
})
return
}
if b.firstPass {
switch newState.ConnectivityState {
case connectivity.Connecting:
// The effective state can be in either IDLE, CONNECTING or
// TRANSIENT_FAILURE. If it's TRANSIENT_FAILURE, stay in
// TRANSIENT_FAILURE until it's READY. See A62.
if sd.effectiveState != connectivity.TransientFailure {
sd.effectiveState = connectivity.Connecting
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: &picker{err: balancer.ErrNoSubConnAvailable},
})
}
case connectivity.TransientFailure:
sd.lastErr = newState.ConnectionError
sd.effectiveState = connectivity.TransientFailure
// Since we're re-using common SubConns while handling resolver
// updates, we could receive an out of turn TRANSIENT_FAILURE from
// a pass over the previous address list. Happy Eyeballs will also
// cause out of order updates to arrive.
if curAddr := b.addressList.currentAddress(); equalAddressIgnoringBalAttributes(&curAddr, &sd.addr) {
b.cancelConnectionTimer()
if b.addressList.increment() {
b.requestConnectionLocked()
return
}
}
// End the first pass if we've seen a TRANSIENT_FAILURE from all
// SubConns once.
b.endFirstPassIfPossibleLocked(newState.ConnectionError)
}
return
}
// We have finished the first pass, keep re-connecting failing SubConns.
switch newState.ConnectivityState {
case connectivity.TransientFailure:
b.numTF = (b.numTF + 1) % b.subConns.Len()
sd.lastErr = newState.ConnectionError
if b.numTF%b.subConns.Len() == 0 {
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: newState.ConnectionError},
})
}
// We don't need to request re-resolution since the SubConn already
// does that before reporting TRANSIENT_FAILURE.
// TODO: #7534 - Move re-resolution requests from SubConn into
// pick_first.
case connectivity.Idle:
sd.subConn.Connect()
}
}
// endFirstPassIfPossibleLocked ends the first happy-eyeballs pass if all the
// addresses are tried and their SubConns have reported a failure.
func (b *pickfirstBalancer) endFirstPassIfPossibleLocked(lastErr error) {
// An optimization to avoid iterating over the entire SubConn map.
if b.addressList.isValid() {
return
}
// Connect() has been called on all the SubConns. The first pass can be
// ended if all the SubConns have reported a failure.
for _, sd := range b.subConns.Values() {
if !sd.connectionFailedInFirstPass {
return
}
}
b.firstPass = false
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: lastErr},
})
// Start re-connecting all the SubConns that are already in IDLE.
for _, sd := range b.subConns.Values() {
if sd.rawConnectivityState == connectivity.Idle {
sd.subConn.Connect()
}
}
}
func (b *pickfirstBalancer) isActiveSCData(sd *scData) bool {
activeSD, found := b.subConns.Get(sd.addr)
return found && activeSD == sd
}
func (b *pickfirstBalancer) updateSubConnHealthState(sd *scData, state balancer.SubConnState) {
b.mu.Lock()
defer b.mu.Unlock()
// Previously relevant SubConns can still callback with state updates.
// To prevent pickers from returning these obsolete SubConns, this logic
// is included to check if the current list of active SubConns includes
// this SubConn.
if !b.isActiveSCData(sd) {
return
}
sd.effectiveState = state.ConnectivityState
switch state.ConnectivityState {
case connectivity.Ready:
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Ready,
Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}},
})
case connectivity.TransientFailure:
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: fmt.Errorf("pickfirst: health check failure: %v", state.ConnectionError)},
})
case connectivity.Connecting:
b.updateBalancerState(balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: &picker{err: balancer.ErrNoSubConnAvailable},
})
default:
b.logger.Errorf("Got unexpected health update for SubConn %p: %v", state)
}
}
// updateBalancerState stores the state reported to the channel and calls
// ClientConn.UpdateState(). As an optimization, it avoids sending duplicate
// updates to the channel.
func (b *pickfirstBalancer) updateBalancerState(newState balancer.State) {
// In case of TransientFailures allow the picker to be updated to update
// the connectivity error, in all other cases don't send duplicate state
// updates.
if newState.ConnectivityState == b.state && b.state != connectivity.TransientFailure {
return
}
b.forceUpdateConcludedStateLocked(newState)
}
// forceUpdateConcludedStateLocked stores the state reported to the channel and
// calls ClientConn.UpdateState().
// A separate function is defined to force update the ClientConn state since the
// channel doesn't correctly assume that LB policies start in CONNECTING and
// relies on LB policy to send an initial CONNECTING update.
func (b *pickfirstBalancer) forceUpdateConcludedStateLocked(newState balancer.State) {
b.state = newState.ConnectivityState
b.cc.UpdateState(newState)
}
type picker struct {
result balancer.PickResult
err error
}
func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
return p.result, p.err
}
// idlePicker is used when the SubConn is IDLE and kicks the SubConn into
// CONNECTING when Pick is called.
type idlePicker struct {
exitIdle func()
}
func (i *idlePicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
i.exitIdle()
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
}
// addressList manages sequentially iterating over addresses present in a list
// of endpoints. It provides a 1 dimensional view of the addresses present in
// the endpoints.
// This type is not safe for concurrent access.
type addressList struct {
addresses []resolver.Address
idx int
}
func (al *addressList) isValid() bool {
return al.idx < len(al.addresses)
}
func (al *addressList) size() int {
return len(al.addresses)
}
// increment moves to the next index in the address list.
// This method returns false if it went off the list, true otherwise.
func (al *addressList) increment() bool {
if !al.isValid() {
return false
}
al.idx++
return al.idx < len(al.addresses)
}
// currentAddress returns the current address pointed to in the addressList.
// If the list is in an invalid state, it returns an empty address instead.
func (al *addressList) currentAddress() resolver.Address {
if !al.isValid() {
return resolver.Address{}
}
return al.addresses[al.idx]
}
func (al *addressList) reset() {
al.idx = 0
}
func (al *addressList) updateAddrs(addrs []resolver.Address) {
al.addresses = addrs
al.reset()
}
// seekTo returns false if the needle was not found and the current index was
// left unchanged.
func (al *addressList) seekTo(needle resolver.Address) bool {
for ai, addr := range al.addresses {
if !equalAddressIgnoringBalAttributes(&addr, &needle) {
continue
}
al.idx = ai
return true
}
return false
}
// hasNext returns whether incrementing the addressList will result in moving
// past the end of the list. If the list has already moved past the end, it
// returns false.
func (al *addressList) hasNext() bool {
if !al.isValid() {
return false
}
return al.idx+1 < len(al.addresses)
}
// equalAddressIgnoringBalAttributes returns true is a and b are considered
// equal. This is different from the Equal method on the resolver.Address type
// which considers all fields to determine equality. Here, we only consider
// fields that are meaningful to the SubConn.
func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool {
return a.Addr == b.Addr && a.ServerName == b.ServerName &&
a.Attributes.Equal(b.Attributes)
}

View File

@@ -26,7 +26,7 @@ import (
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/endpointsharding"
"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
"google.golang.org/grpc/balancer/pickfirst"
"google.golang.org/grpc/grpclog"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
)
@@ -47,7 +47,7 @@ func (bb builder) Name() string {
}
func (bb builder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
childBuilder := balancer.Get(pickfirstleaf.Name).Build
childBuilder := balancer.Get(pickfirst.Name).Build
bal := &rrBalancer{
cc: cc,
Balancer: endpointsharding.NewBalancer(cc, opts, childBuilder, endpointsharding.Options{}),
@@ -67,6 +67,6 @@ func (b *rrBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
return b.Balancer.UpdateClientConnState(balancer.ClientConnState{
// Enable the health listener in pickfirst children for client side health
// checks and outlier detection, if configured.
ResolverState: pickfirstleaf.EnableHealthListener(ccs.ResolverState),
ResolverState: pickfirst.EnableHealthListener(ccs.ResolverState),
})
}

View File

@@ -111,20 +111,6 @@ type SubConnState struct {
// ConnectionError is set if the ConnectivityState is TransientFailure,
// describing the reason the SubConn failed. Otherwise, it is nil.
ConnectionError error
// connectedAddr contains the connected address when ConnectivityState is
// Ready. Otherwise, it is indeterminate.
connectedAddress resolver.Address
}
// connectedAddress returns the connected address for a SubConnState. The
// address is only valid if the state is READY.
func connectedAddress(scs SubConnState) resolver.Address {
return scs.connectedAddress
}
// setConnectedAddress sets the connected address for a SubConnState.
func setConnectedAddress(scs *SubConnState, addr resolver.Address) {
scs.connectedAddress = addr
}
// A Producer is a type shared among potentially many consumers. It is

View File

@@ -36,7 +36,6 @@ import (
)
var (
setConnectedAddress = internal.SetConnectedAddress.(func(*balancer.SubConnState, resolver.Address))
// noOpRegisterHealthListenerFn is used when client side health checking is
// disabled. It sends a single READY update on the registered listener.
noOpRegisterHealthListenerFn = func(_ context.Context, listener func(balancer.SubConnState)) func() {
@@ -305,7 +304,7 @@ func newHealthData(s connectivity.State) *healthData {
// updateState is invoked by grpc to push a subConn state update to the
// underlying balancer.
func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error) {
func (acbw *acBalancerWrapper) updateState(s connectivity.State, err error) {
acbw.ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || acbw.ccb.balancer == nil {
return
@@ -317,9 +316,6 @@ func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolve
// opts.StateListener is set, so this cannot ever be nil.
// TODO: delete this comment when UpdateSubConnState is removed.
scs := balancer.SubConnState{ConnectivityState: s, ConnectionError: err}
if s == connectivity.Ready {
setConnectedAddress(&scs, curAddr)
}
// Invalidate the health listener by updating the healthData.
acbw.healthMu.Lock()
// A race may occur if a health listener is registered soon after the
@@ -450,13 +446,14 @@ func (acbw *acBalancerWrapper) healthListenerRegFn() func(context.Context, func(
if acbw.ccb.cc.dopts.disableHealthCheck {
return noOpRegisterHealthListenerFn
}
cfg := acbw.ac.cc.healthCheckConfig()
if cfg == nil {
return noOpRegisterHealthListenerFn
}
regHealthLisFn := internal.RegisterClientHealthCheckListener
if regHealthLisFn == nil {
// The health package is not imported.
return noOpRegisterHealthListenerFn
}
cfg := acbw.ac.cc.healthCheckConfig()
if cfg == nil {
channelz.Error(logger, acbw.ac.channelz, "Health check is requested but health package is not imported.")
return noOpRegisterHealthListenerFn
}
return func(ctx context.Context, listener func(balancer.SubConnState)) func() {

View File

@@ -18,7 +18,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
// protoc-gen-go v1.36.10
// protoc v5.27.1
// source: grpc/binlog/v1/binarylog.proto

View File

@@ -35,16 +35,19 @@ import (
"google.golang.org/grpc/balancer/pickfirst"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
expstats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/idle"
iresolver "google.golang.org/grpc/internal/resolver"
"google.golang.org/grpc/internal/stats"
istats "google.golang.org/grpc/internal/stats"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
_ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin.
@@ -97,6 +100,41 @@ var (
errTransportCredentialsMissing = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportCredentials() to set)")
)
var (
disconnectionsMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.subchannel.disconnections",
Description: "EXPERIMENTAL. Number of times the selected subchannel becomes disconnected.",
Unit: "{disconnection}",
Labels: []string{"grpc.target"},
OptionalLabels: []string{"grpc.lb.backend_service", "grpc.lb.locality", "grpc.disconnect_error"},
Default: false,
})
connectionAttemptsSucceededMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.subchannel.connection_attempts_succeeded",
Description: "EXPERIMENTAL. Number of successful connection attempts.",
Unit: "{attempt}",
Labels: []string{"grpc.target"},
OptionalLabels: []string{"grpc.lb.backend_service", "grpc.lb.locality"},
Default: false,
})
connectionAttemptsFailedMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
Name: "grpc.subchannel.connection_attempts_failed",
Description: "EXPERIMENTAL. Number of failed connection attempts.",
Unit: "{attempt}",
Labels: []string{"grpc.target"},
OptionalLabels: []string{"grpc.lb.backend_service", "grpc.lb.locality"},
Default: false,
})
openConnectionsMetric = expstats.RegisterInt64UpDownCount(expstats.MetricDescriptor{
Name: "grpc.subchannel.open_connections",
Description: "EXPERIMENTAL. Number of open connections.",
Unit: "{attempt}",
Labels: []string{"grpc.target"},
OptionalLabels: []string{"grpc.lb.backend_service", "grpc.security_level", "grpc.lb.locality"},
Default: false,
})
)
const (
defaultClientMaxReceiveMessageSize = 1024 * 1024 * 4
defaultClientMaxSendMessageSize = math.MaxInt32
@@ -208,9 +246,10 @@ func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error)
channelz.Infof(logger, cc.channelz, "Channel authority set to %q", cc.authority)
cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelz)
cc.pickerWrapper = newPickerWrapper(cc.dopts.copts.StatsHandlers)
cc.pickerWrapper = newPickerWrapper()
cc.metricsRecorderList = stats.NewMetricsRecorderList(cc.dopts.copts.StatsHandlers)
cc.metricsRecorderList = istats.NewMetricsRecorderList(cc.dopts.copts.StatsHandlers)
cc.statsHandler = istats.NewCombinedHandler(cc.dopts.copts.StatsHandlers...)
cc.initIdleStateLocked() // Safe to call without the lock, since nothing else has a reference to cc.
cc.idlenessMgr = idle.NewManager((*idler)(cc), cc.dopts.idleTimeout)
@@ -260,9 +299,10 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
}()
// This creates the name resolver, load balancer, etc.
if err := cc.idlenessMgr.ExitIdleMode(); err != nil {
return nil, err
if err := cc.exitIdleMode(); err != nil {
return nil, fmt.Errorf("failed to exit idle mode: %w", err)
}
cc.idlenessMgr.UnsafeSetNotIdle()
// Return now for non-blocking dials.
if !cc.dopts.block {
@@ -330,7 +370,7 @@ func (cc *ClientConn) addTraceEvent(msg string) {
Severity: channelz.CtInfo,
}
}
channelz.AddTraceEvent(logger, cc.channelz, 0, ted)
channelz.AddTraceEvent(logger, cc.channelz, 1, ted)
}
type idler ClientConn
@@ -339,14 +379,17 @@ func (i *idler) EnterIdleMode() {
(*ClientConn)(i).enterIdleMode()
}
func (i *idler) ExitIdleMode() error {
return (*ClientConn)(i).exitIdleMode()
func (i *idler) ExitIdleMode() {
// Ignore the error returned from this method, because from the perspective
// of the caller (idleness manager), the channel would have always moved out
// of IDLE by the time this method returns.
(*ClientConn)(i).exitIdleMode()
}
// exitIdleMode moves the channel out of idle mode by recreating the name
// resolver and load balancer. This should never be called directly; use
// cc.idlenessMgr.ExitIdleMode instead.
func (cc *ClientConn) exitIdleMode() (err error) {
func (cc *ClientConn) exitIdleMode() error {
cc.mu.Lock()
if cc.conns == nil {
cc.mu.Unlock()
@@ -354,11 +397,23 @@ func (cc *ClientConn) exitIdleMode() (err error) {
}
cc.mu.Unlock()
// Set state to CONNECTING before building the name resolver
// so the channel does not remain in IDLE.
cc.csMgr.updateState(connectivity.Connecting)
// This needs to be called without cc.mu because this builds a new resolver
// which might update state or report error inline, which would then need to
// acquire cc.mu.
if err := cc.resolverWrapper.start(); err != nil {
return err
// If resolver creation fails, treat it like an error reported by the
// resolver before any valid updates. Set channel's state to
// TransientFailure, and set an erroring picker with the resolver build
// error, which will returned as part of any subsequent RPCs.
logger.Warningf("Failed to start resolver: %v", err)
cc.csMgr.updateState(connectivity.TransientFailure)
cc.mu.Lock()
cc.updateResolverStateAndUnlock(resolver.State{}, err)
return fmt.Errorf("failed to start resolver: %w", err)
}
cc.addTraceEvent("exiting idle mode")
@@ -456,7 +511,7 @@ func (cc *ClientConn) validateTransportCredentials() error {
func (cc *ClientConn) channelzRegistration(target string) {
parentChannel, _ := cc.dopts.channelzParent.(*channelz.Channel)
cc.channelz = channelz.RegisterChannel(parentChannel, target)
cc.addTraceEvent("created")
cc.addTraceEvent(fmt.Sprintf("created for target %q", target))
}
// chainUnaryClientInterceptors chains all unary client interceptors into one.
@@ -621,7 +676,8 @@ type ClientConn struct {
channelz *channelz.Channel // Channelz object.
resolverBuilder resolver.Builder // See initParsedTargetAndResolverBuilder().
idlenessMgr *idle.Manager
metricsRecorderList *stats.MetricsRecorderList
metricsRecorderList *istats.MetricsRecorderList
statsHandler stats.Handler
// The following provide their own synchronization, and therefore don't
// require cc.mu to be held to access them.
@@ -678,10 +734,8 @@ func (cc *ClientConn) GetState() connectivity.State {
// Notice: This API is EXPERIMENTAL and may be changed or removed in a later
// release.
func (cc *ClientConn) Connect() {
if err := cc.idlenessMgr.ExitIdleMode(); err != nil {
cc.addTraceEvent(err.Error())
return
}
cc.idlenessMgr.ExitIdleMode()
// If the ClientConn was not in idle mode, we need to call ExitIdle on the
// LB policy so that connections can be created.
cc.mu.Lock()
@@ -732,8 +786,8 @@ func init() {
internal.EnterIdleModeForTesting = func(cc *ClientConn) {
cc.idlenessMgr.EnterIdleModeForTesting()
}
internal.ExitIdleModeForTesting = func(cc *ClientConn) error {
return cc.idlenessMgr.ExitIdleMode()
internal.ExitIdleModeForTesting = func(cc *ClientConn) {
cc.idlenessMgr.ExitIdleMode()
}
}
@@ -858,6 +912,7 @@ func (cc *ClientConn) newAddrConnLocked(addrs []resolver.Address, opts balancer.
channelz: channelz.RegisterSubChannel(cc.channelz, ""),
resetBackoff: make(chan struct{}),
}
ac.updateTelemetryLabelsLocked()
ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
// Start with our address set to the first address; this may be updated if
// we connect to different addresses.
@@ -922,25 +977,24 @@ func (cc *ClientConn) incrCallsFailed() {
// connect starts creating a transport.
// It does nothing if the ac is not IDLE.
// TODO(bar) Move this to the addrConn section.
func (ac *addrConn) connect() error {
func (ac *addrConn) connect() {
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
if logger.V(2) {
logger.Infof("connect called on shutdown addrConn; ignoring.")
}
ac.mu.Unlock()
return errConnClosing
return
}
if ac.state != connectivity.Idle {
if logger.V(2) {
logger.Infof("connect called on addrConn in non-idle state (%v); ignoring.", ac.state)
}
ac.mu.Unlock()
return nil
return
}
ac.resetTransportAndUnlock()
return nil
}
// equalAddressIgnoringBalAttributes returns true is a and b are considered equal.
@@ -974,7 +1028,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
}
ac.addrs = addrs
ac.updateTelemetryLabelsLocked()
if ac.state == connectivity.Shutdown ||
ac.state == connectivity.TransientFailure ||
ac.state == connectivity.Idle {
@@ -1076,13 +1130,6 @@ func (cc *ClientConn) healthCheckConfig() *healthCheckConfig {
return cc.sc.healthCheckConfig
}
func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method string) (transport.ClientTransport, balancer.PickResult, error) {
return cc.pickerWrapper.pick(ctx, failfast, balancer.PickInfo{
Ctx: ctx,
FullMethodName: method,
})
}
func (cc *ClientConn) applyServiceConfigAndBalancer(sc *ServiceConfig, configSelector iresolver.ConfigSelector) {
if sc == nil {
// should never reach here.
@@ -1220,6 +1267,9 @@ type addrConn struct {
resetBackoff chan struct{}
channelz *channelz.SubChannel
localityLabel string
backendServiceLabel string
}
// Note: this requires a lock on ac.mu.
@@ -1227,6 +1277,18 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
if ac.state == s {
return
}
// If we are transitioning out of Ready, it means there is a disconnection.
// A SubConn can also transition from CONNECTING directly to IDLE when
// a transport is successfully created, but the connection fails
// before the SubConn can send the notification for READY. We treat
// this as a successful connection and transition to IDLE.
// TODO: https://github.com/grpc/grpc-go/issues/7862 - Remove the second
// part of the if condition below once the issue is fixed.
if ac.state == connectivity.Ready || (ac.state == connectivity.Connecting && s == connectivity.Idle) {
disconnectionsMetric.Record(ac.cc.metricsRecorderList, 1, ac.cc.target, ac.backendServiceLabel, ac.localityLabel, "unknown")
openConnectionsMetric.Record(ac.cc.metricsRecorderList, -1, ac.cc.target, ac.backendServiceLabel, ac.securityLevelLocked(), ac.localityLabel)
}
ac.state = s
ac.channelz.ChannelMetrics.State.Store(&s)
if lastErr == nil {
@@ -1234,7 +1296,7 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
} else {
channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr)
}
ac.acbw.updateState(s, ac.curAddr, lastErr)
ac.acbw.updateState(s, lastErr)
}
// adjustParams updates parameters used to create transports upon
@@ -1284,6 +1346,15 @@ func (ac *addrConn) resetTransportAndUnlock() {
ac.mu.Unlock()
if err := ac.tryAllAddrs(acCtx, addrs, connectDeadline); err != nil {
if !errors.Is(err, context.Canceled) {
connectionAttemptsFailedMetric.Record(ac.cc.metricsRecorderList, 1, ac.cc.target, ac.backendServiceLabel, ac.localityLabel)
} else {
if logger.V(2) {
// This records cancelled connection attempts which can be later
// replaced by a metric.
logger.Infof("Context cancellation detected; not recording this as a failed connection attempt.")
}
}
// TODO: #7534 - Move re-resolution requests into the pick_first LB policy
// to ensure one resolution request per pass instead of per subconn failure.
ac.cc.resolveNow(resolver.ResolveNowOptions{})
@@ -1323,10 +1394,50 @@ func (ac *addrConn) resetTransportAndUnlock() {
}
// Success; reset backoff.
ac.mu.Lock()
connectionAttemptsSucceededMetric.Record(ac.cc.metricsRecorderList, 1, ac.cc.target, ac.backendServiceLabel, ac.localityLabel)
openConnectionsMetric.Record(ac.cc.metricsRecorderList, 1, ac.cc.target, ac.backendServiceLabel, ac.securityLevelLocked(), ac.localityLabel)
ac.backoffIdx = 0
ac.mu.Unlock()
}
// updateTelemetryLabelsLocked calculates and caches the telemetry labels based on the
// first address in addrConn.
func (ac *addrConn) updateTelemetryLabelsLocked() {
labelsFunc, ok := internal.AddressToTelemetryLabels.(func(resolver.Address) map[string]string)
if !ok || len(ac.addrs) == 0 {
// Reset defaults
ac.localityLabel = ""
ac.backendServiceLabel = ""
return
}
labels := labelsFunc(ac.addrs[0])
ac.localityLabel = labels["grpc.lb.locality"]
ac.backendServiceLabel = labels["grpc.lb.backend_service"]
}
type securityLevelKey struct{}
func (ac *addrConn) securityLevelLocked() string {
var secLevel string
// During disconnection, ac.transport is nil. Fall back to the security level
// stored in the current address during connection.
if ac.transport == nil {
secLevel, _ = ac.curAddr.Attributes.Value(securityLevelKey{}).(string)
return secLevel
}
authInfo := ac.transport.Peer().AuthInfo
if ci, ok := authInfo.(interface {
GetCommonAuthInfo() credentials.CommonAuthInfo
}); ok {
secLevel = ci.GetCommonAuthInfo().SecurityLevel.String()
// Store the security level in the current address' attributes so
// that it remains available for disconnection metrics after the
// transport is closed.
ac.curAddr.Attributes = ac.curAddr.Attributes.WithValue(securityLevelKey{}, secLevel)
}
return secLevel
}
// tryAllAddrs tries to create a connection to the addresses, and stop when at
// the first successful one. It returns an error if no address was successfully
// connected, or updates ac appropriately with the new transport.
@@ -1416,25 +1527,26 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
}
ac.mu.Lock()
defer ac.mu.Unlock()
if ctx.Err() != nil {
// This can happen if the subConn was removed while in `Connecting`
// state. tearDown() would have set the state to `Shutdown`, but
// would not have closed the transport since ac.transport would not
// have been set at that point.
//
// We run this in a goroutine because newTr.Close() calls onClose()
// We unlock ac.mu because newTr.Close() calls onClose()
// inline, which requires locking ac.mu.
//
ac.mu.Unlock()
// The error we pass to Close() is immaterial since there are no open
// streams at this point, so no trailers with error details will be sent
// out. We just need to pass a non-nil error.
//
// This can also happen when updateAddrs is called during a connection
// attempt.
go newTr.Close(transport.ErrConnClosing)
newTr.Close(transport.ErrConnClosing)
return nil
}
defer ac.mu.Unlock()
if hctx.Err() != nil {
// onClose was already called for this connection, but the connection
// was successfully established first. Consider it a success and set
@@ -1831,7 +1943,7 @@ func (cc *ClientConn) initAuthority() error {
} else if auth, ok := cc.resolverBuilder.(resolver.AuthorityOverrider); ok {
cc.authority = auth.OverrideAuthority(cc.parsedTarget)
} else if strings.HasPrefix(endpoint, ":") {
cc.authority = "localhost" + endpoint
cc.authority = "localhost" + encodeAuthority(endpoint)
} else {
cc.authority = encodeAuthority(endpoint)
}

View File

@@ -44,8 +44,7 @@ type PerRPCCredentials interface {
// A54). uri is the URI of the entry point for the request. When supported
// by the underlying implementation, ctx can be used for timeout and
// cancellation. Additionally, RequestInfo data will be available via ctx
// to this call. TODO(zhaoq): Define the set of the qualified keys instead
// of leaving it as an arbitrary string.
// to this call.
GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error)
// RequireTransportSecurity indicates whether the credentials requires
// transport security.
@@ -96,10 +95,11 @@ func (c CommonAuthInfo) GetCommonAuthInfo() CommonAuthInfo {
return c
}
// ProtocolInfo provides information regarding the gRPC wire protocol version,
// security protocol, security protocol version in use, server name, etc.
// ProtocolInfo provides static information regarding transport credentials.
type ProtocolInfo struct {
// ProtocolVersion is the gRPC wire protocol version.
//
// Deprecated: this is unused by gRPC.
ProtocolVersion string
// SecurityProtocol is the security protocol in use.
SecurityProtocol string
@@ -109,7 +109,16 @@ type ProtocolInfo struct {
//
// Deprecated: please use Peer.AuthInfo.
SecurityVersion string
// ServerName is the user-configured server name.
// ServerName is the user-configured server name. If set, this overrides
// the default :authority header used for all RPCs on the channel using the
// containing credentials, unless grpc.WithAuthority is set on the channel,
// in which case that setting will take precedence.
//
// This must be a valid `:authority` header according to
// [RFC3986](https://datatracker.ietf.org/doc/html/rfc3986#section-3.2).
//
// Deprecated: Users should use grpc.WithAuthority to override the authority
// on a channel instead of configuring the credentials.
ServerName string
}
@@ -173,12 +182,17 @@ type TransportCredentials interface {
// Clone makes a copy of this TransportCredentials.
Clone() TransportCredentials
// OverrideServerName specifies the value used for the following:
//
// - verifying the hostname on the returned certificates
// - as SNI in the client's handshake to support virtual hosting
// - as the value for `:authority` header at stream creation time
//
// Deprecated: use grpc.WithAuthority instead. Will be supported
// throughout 1.x.
// The provided string should be a valid `:authority` header according to
// [RFC3986](https://datatracker.ietf.org/doc/html/rfc3986#section-3.2).
//
// Deprecated: this method is unused by gRPC. Users should use
// grpc.WithAuthority to override the authority on a channel instead of
// configuring the credentials.
OverrideServerName(string) error
}

View File

@@ -56,9 +56,13 @@ func (t TLSInfo) AuthType() string {
// non-nil error if the validation fails.
func (t TLSInfo) ValidateAuthority(authority string) error {
var errs []error
host, _, err := net.SplitHostPort(authority)
if err != nil {
host = authority
}
for _, cert := range t.State.PeerCertificates {
var err error
if err = cert.VerifyHostname(authority); err == nil {
if err = cert.VerifyHostname(host); err == nil {
return nil
}
errs = append(errs, err)
@@ -110,14 +114,14 @@ func (c tlsCreds) Info() ProtocolInfo {
func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
// use local cfg to avoid clobbering ServerName if using multiple endpoints
cfg := credinternal.CloneTLSConfig(c.config)
if cfg.ServerName == "" {
serverName, _, err := net.SplitHostPort(authority)
if err != nil {
// If the authority had no host port or if the authority cannot be parsed, use it as-is.
serverName = authority
}
cfg.ServerName = serverName
serverName, _, err := net.SplitHostPort(authority)
if err != nil {
// If the authority had no host port or if the authority cannot be parsed, use it as-is.
serverName = authority
}
cfg.ServerName = serverName
conn := tls.Client(rawConn, cfg)
errChannel := make(chan error, 1)
go func() {
@@ -259,9 +263,11 @@ func applyDefaults(c *tls.Config) *tls.Config {
// certificates to establish the identity of the client need to be included in
// the credentials (eg: for mTLS), use NewTLS instead, where a complete
// tls.Config can be specified.
// serverNameOverride is for testing only. If set to a non empty string,
// it will override the virtual host name of authority (e.g. :authority header
// field) in requests.
//
// serverNameOverride is for testing only. If set to a non empty string, it will
// override the virtual host name of authority (e.g. :authority header field) in
// requests. Users should use grpc.WithAuthority passed to grpc.NewClient to
// override the authority of the client instead.
func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) TransportCredentials {
return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp})
}
@@ -271,9 +277,11 @@ func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) Transpor
// certificates to establish the identity of the client need to be included in
// the credentials (eg: for mTLS), use NewTLS instead, where a complete
// tls.Config can be specified.
// serverNameOverride is for testing only. If set to a non empty string,
// it will override the virtual host name of authority (e.g. :authority header
// field) in requests.
//
// serverNameOverride is for testing only. If set to a non empty string, it will
// override the virtual host name of authority (e.g. :authority header field) in
// requests. Users should use grpc.WithAuthority passed to grpc.NewClient to
// override the authority of the client instead.
func NewClientTLSFromFile(certFile, serverNameOverride string) (TransportCredentials, error) {
b, err := os.ReadFile(certFile)
if err != nil {

View File

@@ -608,6 +608,8 @@ func WithChainStreamInterceptor(interceptors ...StreamClientInterceptor) DialOpt
// WithAuthority returns a DialOption that specifies the value to be used as the
// :authority pseudo-header and as the server name in authentication handshake.
// This overrides all other ways of setting authority on the channel, but can be
// overridden per-call by using grpc.CallAuthority.
func WithAuthority(a string) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.authority = a

View File

@@ -27,8 +27,10 @@ package encoding
import (
"io"
"slices"
"strings"
"google.golang.org/grpc/encoding/internal"
"google.golang.org/grpc/internal/grpcutil"
)
@@ -36,12 +38,26 @@ import (
// It is intended for grpc internal use only.
const Identity = "identity"
func init() {
internal.RegisterCompressorForTesting = func(c Compressor) func() {
name := c.Name()
curCompressor, found := registeredCompressor[name]
RegisterCompressor(c)
return func() {
if found {
registeredCompressor[name] = curCompressor
return
}
delete(registeredCompressor, name)
grpcutil.RegisteredCompressorNames = slices.DeleteFunc(grpcutil.RegisteredCompressorNames, func(s string) bool {
return s == name
})
}
}
}
// Compressor is used for compressing and decompressing when sending or
// receiving messages.
//
// If a Compressor implements `DecompressedSize(compressedBytes []byte) int`,
// gRPC will invoke it to determine the size of the buffer allocated for the
// result of decompression. A return value of -1 indicates unknown size.
type Compressor interface {
// Compress writes the data written to wc to w after compressing it. If an
// error occurs while initializing the compressor, that error is returned

View File

@@ -0,0 +1,28 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package internal contains code internal to the encoding package.
package internal
// RegisterCompressorForTesting registers a compressor in the global compressor
// registry. It returns a cleanup function that should be called at the end
// of the test to unregister the compressor.
//
// This prevents compressors registered in one test from appearing in the
// encoding headers of subsequent tests.
var RegisterCompressorForTesting any // func RegisterCompressor(c Compressor) func()

View File

@@ -46,9 +46,25 @@ func (c *codecV2) Marshal(v any) (data mem.BufferSlice, err error) {
return nil, fmt.Errorf("proto: failed to marshal, message is %T, want proto.Message", v)
}
// Important: if we remove this Size call then we cannot use
// UseCachedSize in MarshalOptions below.
size := proto.Size(vv)
// MarshalOptions with UseCachedSize allows reusing the result from the
// previous Size call. This is safe here because:
//
// 1. We just computed the size.
// 2. We assume the message is not being mutated concurrently.
//
// Important: If the proto.Size call above is removed, using UseCachedSize
// becomes unsafe and may lead to incorrect marshaling.
//
// For more details, see the doc of UseCachedSize:
// https://pkg.go.dev/google.golang.org/protobuf/proto#MarshalOptions
marshalOptions := proto.MarshalOptions{UseCachedSize: true}
if mem.IsBelowBufferPoolingThreshold(size) {
buf, err := proto.Marshal(vv)
buf, err := marshalOptions.Marshal(vv)
if err != nil {
return nil, err
}
@@ -56,7 +72,7 @@ func (c *codecV2) Marshal(v any) (data mem.BufferSlice, err error) {
} else {
pool := mem.DefaultBufferPool()
buf := pool.Get(size)
if _, err := (proto.MarshalOptions{}).MarshalAppend((*buf)[:0], vv); err != nil {
if _, err := marshalOptions.MarshalAppend((*buf)[:0], vv); err != nil {
pool.Put(buf)
return nil, err
}

View File

@@ -75,6 +75,8 @@ const (
MetricTypeIntHisto
MetricTypeFloatHisto
MetricTypeIntGauge
MetricTypeIntUpDownCount
MetricTypeIntAsyncGauge
)
// Int64CountHandle is a typed handle for a int count metric. This handle
@@ -93,6 +95,23 @@ func (h *Int64CountHandle) Record(recorder MetricsRecorder, incr int64, labels .
recorder.RecordInt64Count(h, incr, labels...)
}
// Int64UpDownCountHandle is a typed handle for an int up-down counter metric.
// This handle is passed at the recording point in order to know which metric
// to record on.
type Int64UpDownCountHandle MetricDescriptor
// Descriptor returns the int64 up-down counter handle typecast to a pointer to a
// MetricDescriptor.
func (h *Int64UpDownCountHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the int64 up-down counter value on the metrics recorder provided.
// The value 'v' can be positive to increment or negative to decrement.
func (h *Int64UpDownCountHandle) Record(recorder MetricsRecorder, v int64, labels ...string) {
recorder.RecordInt64UpDownCount(h, v, labels...)
}
// Float64CountHandle is a typed handle for a float count metric. This handle is
// passed at the recording point in order to know which metric to record on.
type Float64CountHandle MetricDescriptor
@@ -154,6 +173,30 @@ func (h *Int64GaugeHandle) Record(recorder MetricsRecorder, incr int64, labels .
recorder.RecordInt64Gauge(h, incr, labels...)
}
// AsyncMetric is a marker interface for asynchronous metric types.
type AsyncMetric interface {
isAsync()
Descriptor() *MetricDescriptor
}
// Int64AsyncGaugeHandle is a typed handle for an int gauge metric. This handle is
// passed at the recording point in order to know which metric to record on.
type Int64AsyncGaugeHandle MetricDescriptor
// isAsync implements the AsyncMetric interface.
func (h *Int64AsyncGaugeHandle) isAsync() {}
// Descriptor returns the int64 gauge handle typecast to a pointer to a
// MetricDescriptor.
func (h *Int64AsyncGaugeHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the int64 gauge value on the metrics recorder provided.
func (h *Int64AsyncGaugeHandle) Record(recorder AsyncMetricsRecorder, value int64, labels ...string) {
recorder.RecordInt64AsyncGauge(h, value, labels...)
}
// registeredMetrics are the registered metric descriptor names.
var registeredMetrics = make(map[string]bool)
@@ -249,6 +292,35 @@ func RegisterInt64Gauge(descriptor MetricDescriptor) *Int64GaugeHandle {
return (*Int64GaugeHandle)(descPtr)
}
// RegisterInt64UpDownCount registers the metric description onto the global registry.
// It returns a typed handle to use for recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterInt64UpDownCount(descriptor MetricDescriptor) *Int64UpDownCountHandle {
registerMetric(descriptor.Name, descriptor.Default)
// Set the specific metric type for the up-down counter
descriptor.Type = MetricTypeIntUpDownCount
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Int64UpDownCountHandle)(descPtr)
}
// RegisterInt64AsyncGauge registers the metric description onto the global registry.
// It returns a typed handle to use for recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterInt64AsyncGauge(descriptor MetricDescriptor) *Int64AsyncGaugeHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeIntAsyncGauge
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Int64AsyncGaugeHandle)(descPtr)
}
// snapshotMetricsRegistryForTesting snapshots the global data of the metrics
// registry. Returns a cleanup function that sets the metrics registry to its
// original state.

View File

@@ -19,9 +19,13 @@
// Package stats contains experimental metrics/stats API's.
package stats
import "google.golang.org/grpc/stats"
import (
"google.golang.org/grpc/internal"
"google.golang.org/grpc/stats"
)
// MetricsRecorder records on metrics derived from metric registry.
// Implementors must embed UnimplementedMetricsRecorder.
type MetricsRecorder interface {
// RecordInt64Count records the measurement alongside labels on the int
// count associated with the provided handle.
@@ -38,6 +42,49 @@ type MetricsRecorder interface {
// RecordInt64Gauge records the measurement alongside labels on the int
// gauge associated with the provided handle.
RecordInt64Gauge(handle *Int64GaugeHandle, incr int64, labels ...string)
// RecordInt64UpDownCounter records the measurement alongside labels on the int
// count associated with the provided handle.
RecordInt64UpDownCount(handle *Int64UpDownCountHandle, incr int64, labels ...string)
// RegisterAsyncReporter registers a reporter to produce metric values for
// only the listed descriptors. The returned function must be called when
// the metrics are no longer needed, which will remove the reporter. The
// returned method needs to be idempotent and concurrent safe.
RegisterAsyncReporter(reporter AsyncMetricReporter, descriptors ...AsyncMetric) func()
// EnforceMetricsRecorderEmbedding is included to force implementers to embed
// another implementation of this interface, allowing gRPC to add methods
// without breaking users.
internal.EnforceMetricsRecorderEmbedding
}
// AsyncMetricReporter is an interface for types that record metrics asynchronously
// for the set of descriptors they are registered with. The AsyncMetricsRecorder
// parameter is used to record values for these metrics.
//
// Implementations must make unique recordings across all registered
// AsyncMetricReporters. Meaning, they should not report values for a metric with
// the same attributes as another AsyncMetricReporter will report.
//
// Implementations must be concurrent-safe.
type AsyncMetricReporter interface {
// Report records metric values using the provided recorder.
Report(AsyncMetricsRecorder) error
}
// AsyncMetricReporterFunc is an adapter to allow the use of ordinary functions as
// AsyncMetricReporters.
type AsyncMetricReporterFunc func(AsyncMetricsRecorder) error
// Report calls f(r).
func (f AsyncMetricReporterFunc) Report(r AsyncMetricsRecorder) error {
return f(r)
}
// AsyncMetricsRecorder records on asynchronous metrics derived from metric registry.
type AsyncMetricsRecorder interface {
// RecordInt64AsyncGauge records the measurement alongside labels on the int
// count associated with the provided handle asynchronously
RecordInt64AsyncGauge(handle *Int64AsyncGaugeHandle, incr int64, labels ...string)
}
// Metrics is an experimental legacy alias of the now-stable stats.MetricSet.
@@ -52,3 +99,33 @@ type Metric = string
func NewMetrics(metrics ...Metric) *Metrics {
return stats.NewMetricSet(metrics...)
}
// UnimplementedMetricsRecorder must be embedded to have forward compatible implementations.
type UnimplementedMetricsRecorder struct {
internal.EnforceMetricsRecorderEmbedding
}
// RecordInt64Count provides a no-op implementation.
func (UnimplementedMetricsRecorder) RecordInt64Count(*Int64CountHandle, int64, ...string) {}
// RecordFloat64Count provides a no-op implementation.
func (UnimplementedMetricsRecorder) RecordFloat64Count(*Float64CountHandle, float64, ...string) {}
// RecordInt64Histo provides a no-op implementation.
func (UnimplementedMetricsRecorder) RecordInt64Histo(*Int64HistoHandle, int64, ...string) {}
// RecordFloat64Histo provides a no-op implementation.
func (UnimplementedMetricsRecorder) RecordFloat64Histo(*Float64HistoHandle, float64, ...string) {}
// RecordInt64Gauge provides a no-op implementation.
func (UnimplementedMetricsRecorder) RecordInt64Gauge(*Int64GaugeHandle, int64, ...string) {}
// RecordInt64UpDownCount provides a no-op implementation.
func (UnimplementedMetricsRecorder) RecordInt64UpDownCount(*Int64UpDownCountHandle, int64, ...string) {
}
// RegisterAsyncReporter provides a no-op implementation.
func (UnimplementedMetricsRecorder) RegisterAsyncReporter(AsyncMetricReporter, ...AsyncMetric) func() {
// No-op: Return an empty function to ensure caller doesn't panic on nil function call
return func() {}
}

View File

@@ -97,8 +97,12 @@ type StreamServerInfo struct {
IsServerStream bool
}
// StreamServerInterceptor provides a hook to intercept the execution of a streaming RPC on the server.
// info contains all the information of this RPC the interceptor can operate on. And handler is the
// service method implementation. It is the responsibility of the interceptor to invoke handler to
// complete the RPC.
// StreamServerInterceptor provides a hook to intercept the execution of a
// streaming RPC on the server.
//
// srv is the service implementation on which the RPC was invoked, and needs to
// be passed to handler, and not used otherwise. ss is the server side of the
// stream. info contains all the information of this RPC the interceptor can
// operate on. And handler is the service method implementation. It is the
// responsibility of the interceptor to invoke handler to complete the RPC.
type StreamServerInterceptor func(srv any, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error

View File

@@ -67,6 +67,10 @@ type Balancer struct {
// balancerCurrent before the UpdateSubConnState is called on the
// balancerCurrent.
currentMu sync.Mutex
// activeGoroutines tracks all the goroutines that this balancer has started
// and that should be waited on when the balancer closes.
activeGoroutines sync.WaitGroup
}
// swap swaps out the current lb with the pending lb and updates the ClientConn.
@@ -76,7 +80,9 @@ func (gsb *Balancer) swap() {
cur := gsb.balancerCurrent
gsb.balancerCurrent = gsb.balancerPending
gsb.balancerPending = nil
gsb.activeGoroutines.Add(1)
go func() {
defer gsb.activeGoroutines.Done()
gsb.currentMu.Lock()
defer gsb.currentMu.Unlock()
cur.Close()
@@ -274,6 +280,7 @@ func (gsb *Balancer) Close() {
currentBalancerToClose.Close()
pendingBalancerToClose.Close()
gsb.activeGoroutines.Wait()
}
// balancerWrapper wraps a balancer.Balancer, and overrides some Balancer
@@ -324,7 +331,12 @@ func (bw *balancerWrapper) UpdateState(state balancer.State) {
defer bw.gsb.mu.Unlock()
bw.lastState = state
// If Close() acquires the mutex before UpdateState(), the balancer
// will already have been removed from the current or pending state when
// reaching this point.
if !bw.gsb.balancerCurrentOrPending(bw) {
// Returning here ensures that (*Balancer).swap() is not invoked after
// (*Balancer).Close() and therefore prevents "use after close".
return
}

View File

@@ -0,0 +1,66 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package weight contains utilities to manage endpoint weights. Weights are
// used by LB policies such as ringhash to distribute load across multiple
// endpoints.
package weight
import (
"fmt"
"google.golang.org/grpc/resolver"
)
// attributeKey is the type used as the key to store EndpointInfo in the
// Attributes field of resolver.Endpoint.
type attributeKey struct{}
// EndpointInfo will be stored in the Attributes field of Endpoints in order to
// use the ringhash balancer.
type EndpointInfo struct {
Weight uint32
}
// Equal allows the values to be compared by Attributes.Equal.
func (a EndpointInfo) Equal(o any) bool {
oa, ok := o.(EndpointInfo)
return ok && oa.Weight == a.Weight
}
// Set returns a copy of endpoint in which the Attributes field is updated with
// EndpointInfo.
func Set(endpoint resolver.Endpoint, epInfo EndpointInfo) resolver.Endpoint {
endpoint.Attributes = endpoint.Attributes.WithValue(attributeKey{}, epInfo)
return endpoint
}
// String returns a human-readable representation of EndpointInfo.
// This method is intended for logging, testing, and debugging purposes only.
// Do not rely on the output format, as it is not guaranteed to remain stable.
func (a EndpointInfo) String() string {
return fmt.Sprintf("Weight: %d", a.Weight)
}
// FromEndpoint returns the EndpointInfo stored in the Attributes field of an
// endpoint. It returns an empty EndpointInfo if attribute is not found.
func FromEndpoint(endpoint resolver.Endpoint) EndpointInfo {
v := endpoint.Attributes.Value(attributeKey{})
ei, _ := v.(EndpointInfo)
return ei
}

View File

@@ -83,6 +83,7 @@ func (b *Unbounded) Load() {
default:
}
} else if b.closing && !b.closed {
b.closed = true
close(b.c)
}
}

View File

@@ -194,7 +194,7 @@ func (r RefChannelType) String() string {
// If channelz is not turned ON, this will simply log the event descriptions.
func AddTraceEvent(l grpclog.DepthLoggerV2, e Entity, depth int, desc *TraceEvent) {
// Log only the trace description associated with the bottom most entity.
d := fmt.Sprintf("[%s]%s", e, desc.Desc)
d := fmt.Sprintf("[%s] %s", e, desc.Desc)
switch desc.Severity {
case CtUnknown, CtInfo:
l.InfoDepth(depth+1, d)

View File

@@ -26,31 +26,31 @@ import (
)
var (
// TXTErrIgnore is set if TXT errors should be ignored ("GRPC_GO_IGNORE_TXT_ERRORS" is not "false").
// EnableTXTServiceConfig is set if the DNS resolver should perform TXT
// lookups for service config ("GRPC_ENABLE_TXT_SERVICE_CONFIG" is not
// "false").
EnableTXTServiceConfig = boolFromEnv("GRPC_ENABLE_TXT_SERVICE_CONFIG", true)
// TXTErrIgnore is set if TXT errors should be ignored
// ("GRPC_GO_IGNORE_TXT_ERRORS" is not "false").
TXTErrIgnore = boolFromEnv("GRPC_GO_IGNORE_TXT_ERRORS", true)
// RingHashCap indicates the maximum ring size which defaults to 4096
// entries but may be overridden by setting the environment variable
// "GRPC_RING_HASH_CAP". This does not override the default bounds
// checking which NACKs configs specifying ring sizes > 8*1024*1024 (~8M).
RingHashCap = uint64FromEnv("GRPC_RING_HASH_CAP", 4096, 1, 8*1024*1024)
// ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS
// handshakes that can be performed.
ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100)
// EnforceALPNEnabled is set if TLS connections to servers with ALPN disabled
// should be rejected. The HTTP/2 protocol requires ALPN to be enabled, this
// option is present for backward compatibility. This option may be overridden
// by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true"
// or "false".
EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", true)
// XDSFallbackSupport is the env variable that controls whether support for
// xDS fallback is turned on. If this is unset or is false, only the first
// xDS server in the list of server configs will be used.
XDSFallbackSupport = boolFromEnv("GRPC_EXPERIMENTAL_XDS_FALLBACK", true)
// NewPickFirstEnabled is set if the new pickfirst leaf policy is to be used
// instead of the exiting pickfirst implementation. This can be disabled by
// setting the environment variable "GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST"
// to "false".
NewPickFirstEnabled = boolFromEnv("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST", true)
// XDSEndpointHashKeyBackwardCompat controls the parsing of the endpoint hash
// key from EDS LbEndpoint metadata. Endpoint hash keys can be disabled by
@@ -69,6 +69,41 @@ var (
// ALTSHandshakerKeepaliveParams is set if we should add the
// KeepaliveParams when dial the ALTS handshaker service.
ALTSHandshakerKeepaliveParams = boolFromEnv("GRPC_EXPERIMENTAL_ALTS_HANDSHAKER_KEEPALIVE_PARAMS", false)
// EnableDefaultPortForProxyTarget controls whether the resolver adds a default port 443
// to a target address that lacks one. This flag only has an effect when all of
// the following conditions are met:
// - A connect proxy is being used.
// - Target resolution is disabled.
// - The DNS resolver is being used.
EnableDefaultPortForProxyTarget = boolFromEnv("GRPC_EXPERIMENTAL_ENABLE_DEFAULT_PORT_FOR_PROXY_TARGET", true)
// XDSAuthorityRewrite indicates whether xDS authority rewriting is enabled.
// This feature is defined in gRFC A81 and is enabled by setting the
// environment variable GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE to "true".
XDSAuthorityRewrite = boolFromEnv("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", false)
// PickFirstWeightedShuffling indicates whether weighted endpoint shuffling
// is enabled in the pick_first LB policy, as defined in gRFC A113. This
// feature can be disabled by setting the environment variable
// GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING to "false".
PickFirstWeightedShuffling = boolFromEnv("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true)
// DisableStrictPathChecking indicates whether strict path checking is
// disabled. This feature can be disabled by setting the environment
// variable GRPC_GO_EXPERIMENTAL_DISABLE_STRICT_PATH_CHECKING to "true".
//
// When strict path checking is enabled, gRPC will reject requests with
// paths that do not conform to the gRPC over HTTP/2 specification found at
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md.
//
// When disabled, gRPC will allow paths that do not contain a leading slash.
// Enabling strict path checking is recommended for security reasons, as it
// prevents potential path traversal vulnerabilities.
//
// A future release will remove this environment variable, enabling strict
// path checking behavior unconditionally.
DisableStrictPathChecking = boolFromEnv("GRPC_GO_EXPERIMENTAL_DISABLE_STRICT_PATH_CHECKING", false)
)
func boolFromEnv(envVar string, def bool) bool {

View File

@@ -68,4 +68,15 @@ var (
// trust. For more details, see:
// https://github.com/grpc/proposal/blob/master/A87-mtls-spiffe-support.md
XDSSPIFFEEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_MTLS_SPIFFE", false)
// XDSHTTPConnectEnabled is true if gRPC should parse custom Metadata
// configuring use of an HTTP CONNECT proxy via xDS from cluster resources.
// For more details, see:
// https://github.com/grpc/proposal/blob/master/A86-xds-http-connect.md
XDSHTTPConnectEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_HTTP_CONNECT", false)
// XDSBootstrapCallCredsEnabled controls if call credentials can be used in
// xDS bootstrap configuration via the `call_creds` field. For more details,
// see: https://github.com/grpc/proposal/blob/master/A97-xds-jwt-call-creds.md
XDSBootstrapCallCredsEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_BOOTSTRAP_CALL_CREDS", false)
)

View File

@@ -25,4 +25,11 @@ var (
// BufferPool is implemented by the grpc package and returns a server
// option to configure a shared buffer pool for a grpc.Server.
BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption
// SetDefaultBufferPool updates the default buffer pool.
SetDefaultBufferPool any // func(mem.BufferPool)
// AcceptCompressors is implemented by the grpc package and returns
// a call option that restricts the grpc-accept-encoding header for a call.
AcceptCompressors any // func(...string) grpc.CallOption
)

View File

@@ -80,25 +80,11 @@ func (cs *CallbackSerializer) ScheduleOr(f func(ctx context.Context), onFailure
func (cs *CallbackSerializer) run(ctx context.Context) {
defer close(cs.done)
// TODO: when Go 1.21 is the oldest supported version, this loop and Close
// can be replaced with:
//
// context.AfterFunc(ctx, cs.callbacks.Close)
for ctx.Err() == nil {
select {
case <-ctx.Done():
// Do nothing here. Next iteration of the for loop will not happen,
// since ctx.Err() would be non-nil.
case cb := <-cs.callbacks.Get():
cs.callbacks.Load()
cb.(func(context.Context))(ctx)
}
}
// Close the buffer when the context is canceled
// to prevent new callbacks from being added.
context.AfterFunc(ctx, cs.callbacks.Close)
// Close the buffer to prevent new callbacks from being added.
cs.callbacks.Close()
// Run all pending callbacks.
// Run all callbacks.
for cb := range cs.callbacks.Get() {
cs.callbacks.Load()
cb.(func(context.Context))(ctx)

View File

@@ -21,7 +21,6 @@
package idle
import (
"fmt"
"math"
"sync"
"sync/atomic"
@@ -33,15 +32,15 @@ var timeAfterFunc = func(d time.Duration, f func()) *time.Timer {
return time.AfterFunc(d, f)
}
// Enforcer is the functionality provided by grpc.ClientConn to enter
// and exit from idle mode.
type Enforcer interface {
ExitIdleMode() error
// ClientConn is the functionality provided by grpc.ClientConn to enter and exit
// from idle mode.
type ClientConn interface {
ExitIdleMode()
EnterIdleMode()
}
// Manager implements idleness detection and calls the configured Enforcer to
// enter/exit idle mode when appropriate. Must be created by NewManager.
// Manager implements idleness detection and calls the ClientConn to enter/exit
// idle mode when appropriate. Must be created by NewManager.
type Manager struct {
// State accessed atomically.
lastCallEndTime int64 // Unix timestamp in nanos; time when the most recent RPC completed.
@@ -51,8 +50,8 @@ type Manager struct {
// Can be accessed without atomics or mutex since these are set at creation
// time and read-only after that.
enforcer Enforcer // Functionality provided by grpc.ClientConn.
timeout time.Duration
cc ClientConn // Functionality provided by grpc.ClientConn.
timeout time.Duration
// idleMu is used to guarantee mutual exclusion in two scenarios:
// - Opposing intentions:
@@ -72,9 +71,9 @@ type Manager struct {
// NewManager creates a new idleness manager implementation for the
// given idle timeout. It begins in idle mode.
func NewManager(enforcer Enforcer, timeout time.Duration) *Manager {
func NewManager(cc ClientConn, timeout time.Duration) *Manager {
return &Manager{
enforcer: enforcer,
cc: cc,
timeout: timeout,
actuallyIdle: true,
activeCallsCount: -math.MaxInt32,
@@ -127,7 +126,7 @@ func (m *Manager) handleIdleTimeout() {
// Now that we've checked that there has been no activity, attempt to enter
// idle mode, which is very likely to succeed.
if m.tryEnterIdleMode() {
if m.tryEnterIdleMode(true) {
// Successfully entered idle mode. No timer needed until we exit idle.
return
}
@@ -142,10 +141,13 @@ func (m *Manager) handleIdleTimeout() {
// that, it performs a last minute check to ensure that no new RPC has come in,
// making the channel active.
//
// checkActivity controls if a check for RPC activity, since the last time the
// idle_timeout fired, is made.
// Return value indicates whether or not the channel moved to idle mode.
//
// Holds idleMu which ensures mutual exclusion with exitIdleMode.
func (m *Manager) tryEnterIdleMode() bool {
func (m *Manager) tryEnterIdleMode(checkActivity bool) bool {
// Setting the activeCallsCount to -math.MaxInt32 indicates to OnCallBegin()
// that the channel is either in idle mode or is trying to get there.
if !atomic.CompareAndSwapInt32(&m.activeCallsCount, 0, -math.MaxInt32) {
@@ -166,7 +168,7 @@ func (m *Manager) tryEnterIdleMode() bool {
atomic.AddInt32(&m.activeCallsCount, math.MaxInt32)
return false
}
if atomic.LoadInt32(&m.activeSinceLastTimerCheck) == 1 {
if checkActivity && atomic.LoadInt32(&m.activeSinceLastTimerCheck) == 1 {
// A very short RPC could have come in (and also finished) after we
// checked for calls count and activity in handleIdleTimeout(), but
// before the CAS operation. So, we need to check for activity again.
@@ -177,44 +179,37 @@ func (m *Manager) tryEnterIdleMode() bool {
// No new RPCs have come in since we set the active calls count value to
// -math.MaxInt32. And since we have the lock, it is safe to enter idle mode
// unconditionally now.
m.enforcer.EnterIdleMode()
m.cc.EnterIdleMode()
m.actuallyIdle = true
return true
}
// EnterIdleModeForTesting instructs the channel to enter idle mode.
func (m *Manager) EnterIdleModeForTesting() {
m.tryEnterIdleMode()
m.tryEnterIdleMode(false)
}
// OnCallBegin is invoked at the start of every RPC.
func (m *Manager) OnCallBegin() error {
func (m *Manager) OnCallBegin() {
if m.isClosed() {
return nil
return
}
if atomic.AddInt32(&m.activeCallsCount, 1) > 0 {
// Channel is not idle now. Set the activity bit and allow the call.
atomic.StoreInt32(&m.activeSinceLastTimerCheck, 1)
return nil
return
}
// Channel is either in idle mode or is in the process of moving to idle
// mode. Attempt to exit idle mode to allow this RPC.
if err := m.ExitIdleMode(); err != nil {
// Undo the increment to calls count, and return an error causing the
// RPC to fail.
atomic.AddInt32(&m.activeCallsCount, -1)
return err
}
m.ExitIdleMode()
atomic.StoreInt32(&m.activeSinceLastTimerCheck, 1)
return nil
}
// ExitIdleMode instructs m to call the enforcer's ExitIdleMode and update m's
// ExitIdleMode instructs m to call the ClientConn's ExitIdleMode and update its
// internal state.
func (m *Manager) ExitIdleMode() error {
func (m *Manager) ExitIdleMode() {
// Holds idleMu which ensures mutual exclusion with tryEnterIdleMode.
m.idleMu.Lock()
defer m.idleMu.Unlock()
@@ -231,12 +226,10 @@ func (m *Manager) ExitIdleMode() error {
// m.ExitIdleMode.
//
// In any case, there is nothing to do here.
return nil
return
}
if err := m.enforcer.ExitIdleMode(); err != nil {
return fmt.Errorf("failed to exit idle mode: %w", err)
}
m.cc.ExitIdleMode()
// Undo the idle entry process. This also respects any new RPC attempts.
atomic.AddInt32(&m.activeCallsCount, math.MaxInt32)
@@ -244,7 +237,23 @@ func (m *Manager) ExitIdleMode() error {
// Start a new timer to fire after the configured idle timeout.
m.resetIdleTimerLocked(m.timeout)
return nil
}
// UnsafeSetNotIdle instructs the Manager to update its internal state to
// reflect the reality that the channel is no longer in IDLE mode.
//
// N.B. This method is intended only for internal use by the gRPC client
// when it exits IDLE mode **manually** from `Dial`. The callsite must ensure:
// - The channel was **actually in IDLE mode** immediately prior to the call.
// - There is **no concurrent activity** that could cause the channel to exit
// IDLE mode *naturally* at the same time.
func (m *Manager) UnsafeSetNotIdle() {
m.idleMu.Lock()
defer m.idleMu.Unlock()
atomic.AddInt32(&m.activeCallsCount, math.MaxInt32)
m.actuallyIdle = false
m.resetIdleTimerLocked(m.timeout)
}
// OnCallEnd is invoked at the end of every RPC.

View File

@@ -182,35 +182,6 @@ var (
// other features, including the CSDS service.
NewXDSResolverWithClientForTesting any // func(xdsclient.XDSClient) (resolver.Builder, error)
// RegisterRLSClusterSpecifierPluginForTesting registers the RLS Cluster
// Specifier Plugin for testing purposes, regardless of the XDSRLS environment
// variable.
//
// TODO: Remove this function once the RLS env var is removed.
RegisterRLSClusterSpecifierPluginForTesting func()
// UnregisterRLSClusterSpecifierPluginForTesting unregisters the RLS Cluster
// Specifier Plugin for testing purposes. This is needed because there is no way
// to unregister the RLS Cluster Specifier Plugin after registering it solely
// for testing purposes using RegisterRLSClusterSpecifierPluginForTesting().
//
// TODO: Remove this function once the RLS env var is removed.
UnregisterRLSClusterSpecifierPluginForTesting func()
// RegisterRBACHTTPFilterForTesting registers the RBAC HTTP Filter for testing
// purposes, regardless of the RBAC environment variable.
//
// TODO: Remove this function once the RBAC env var is removed.
RegisterRBACHTTPFilterForTesting func()
// UnregisterRBACHTTPFilterForTesting unregisters the RBAC HTTP Filter for
// testing purposes. This is needed because there is no way to unregister the
// HTTP Filter after registering it solely for testing purposes using
// RegisterRBACHTTPFilterForTesting().
//
// TODO: Remove this function once the RBAC env var is removed.
UnregisterRBACHTTPFilterForTesting func()
// ORCAAllowAnyMinReportingInterval is for examples/orca use ONLY.
ORCAAllowAnyMinReportingInterval any // func(so *orca.ServiceOptions)
@@ -240,22 +211,11 @@ var (
// default resolver scheme.
UserSetDefaultScheme = false
// ConnectedAddress returns the connected address for a SubConnState. The
// address is only valid if the state is READY.
ConnectedAddress any // func (scs SubConnState) resolver.Address
// SetConnectedAddress sets the connected address for a SubConnState.
SetConnectedAddress any // func(scs *SubConnState, addr resolver.Address)
// SnapshotMetricRegistryForTesting snapshots the global data of the metric
// registry. Returns a cleanup function that sets the metric registry to its
// original state. Only called in testing functions.
SnapshotMetricRegistryForTesting func() func()
// SetDefaultBufferPoolForTesting updates the default buffer pool, for
// testing purposes.
SetDefaultBufferPoolForTesting any // func(mem.BufferPool)
// SetBufferPoolingThresholdForTesting updates the buffer pooling threshold, for
// testing purposes.
SetBufferPoolingThresholdForTesting any // func(int)
@@ -273,6 +233,18 @@ var (
// When set, the function will be called before the stream enters
// the blocking state.
NewStreamWaitingForResolver = func() {}
// AddressToTelemetryLabels is an xDS-provided function to extract telemetry
// labels from a resolver.Address. Callers must assert its type before calling.
AddressToTelemetryLabels any // func(addr resolver.Address) map[string]string
// AsyncReporterCleanupDelegate is initialized to a pass-through function by
// default (production behavior), allowing tests to swap it with an
// implementation which tracks registration of async reporter and its
// corresponding cleanup.
AsyncReporterCleanupDelegate = func(cleanup func()) func() {
return cleanup
}
)
// HealthChecker defines the signature of the client-side LB channel health
@@ -320,3 +292,9 @@ type EnforceClientConnEmbedding interface {
type Timer interface {
Stop() bool
}
// EnforceMetricsRecorderEmbedding is used to enforce proper MetricsRecorder
// implementation embedding.
type EnforceMetricsRecorderEmbedding interface {
enforceMetricsRecorderEmbedding()
}

View File

@@ -22,11 +22,13 @@ package delegatingresolver
import (
"fmt"
"net"
"net/http"
"net/url"
"sync"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/proxyattributes"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/internal/transport/networktype"
@@ -40,6 +42,8 @@ var (
HTTPSProxyFromEnvironment = http.ProxyFromEnvironment
)
const defaultPort = "443"
// delegatingResolver manages both target URI and proxy address resolution by
// delegating these tasks to separate child resolvers. Essentially, it acts as
// an intermediary between the gRPC ClientConn and the child resolvers.
@@ -107,10 +111,18 @@ func New(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOpti
targetResolver: nopResolver{},
}
addr := target.Endpoint()
var err error
r.proxyURL, err = proxyURLForTarget(target.Endpoint())
if target.URL.Scheme == "dns" && !targetResolutionEnabled && envconfig.EnableDefaultPortForProxyTarget {
addr, err = parseTarget(addr)
if err != nil {
return nil, fmt.Errorf("delegating_resolver: invalid target address %q: %v", target.Endpoint(), err)
}
}
r.proxyURL, err = proxyURLForTarget(addr)
if err != nil {
return nil, fmt.Errorf("delegating_resolver: failed to determine proxy URL for target %s: %v", target, err)
return nil, fmt.Errorf("delegating_resolver: failed to determine proxy URL for target %q: %v", target, err)
}
// proxy is not configured or proxy address excluded using `NO_PROXY` env
@@ -132,8 +144,8 @@ func New(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOpti
// bypass the target resolver and store the unresolved target address.
if target.URL.Scheme == "dns" && !targetResolutionEnabled {
r.targetResolverState = &resolver.State{
Addresses: []resolver.Address{{Addr: target.Endpoint()}},
Endpoints: []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: target.Endpoint()}}}},
Addresses: []resolver.Address{{Addr: addr}},
Endpoints: []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: addr}}}},
}
r.updateTargetResolverState(*r.targetResolverState)
return r, nil
@@ -202,6 +214,44 @@ func needsProxyResolver(state *resolver.State) bool {
return false
}
// parseTarget takes a target string and ensures it is a valid "host:port" target.
//
// It does the following:
// 1. If the target already has a port (e.g., "host:port", "[ipv6]:port"),
// it is returned as is.
// 2. If the host part is empty (e.g., ":80"), it defaults to "localhost",
// returning "localhost:80".
// 3. If the target is missing a port (e.g., "host", "ipv6"), the defaultPort
// is added.
//
// An error is returned for empty targets or targets with a trailing colon
// but no port (e.g., "host:").
func parseTarget(target string) (string, error) {
if target == "" {
return "", fmt.Errorf("missing address")
}
host, port, err := net.SplitHostPort(target)
if err != nil {
// If SplitHostPort fails, it's likely because the port is missing.
// We append the default port and return the result.
return net.JoinHostPort(target, defaultPort), nil
}
// If SplitHostPort succeeds, we check for edge cases.
if port == "" {
// A success with an empty port means the target had a trailing colon,
// e.g., "host:", which is an error.
return "", fmt.Errorf("missing port after port-separator colon")
}
if host == "" {
// A success with an empty host means the target was like ":80".
// We default the host to "localhost".
host = "localhost"
}
return net.JoinHostPort(host, port), nil
}
func skipProxy(address resolver.Address) bool {
// Avoid proxy when network is not tcp.
networkType, ok := networktype.Get(address)

View File

@@ -125,20 +125,23 @@ func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts
// IP address.
if ipAddr, err := formatIP(host); err == nil {
addr := []resolver.Address{{Addr: ipAddr + ":" + port}}
cc.UpdateState(resolver.State{Addresses: addr})
cc.UpdateState(resolver.State{
Addresses: addr,
Endpoints: []resolver.Endpoint{{Addresses: addr}},
})
return deadResolver{}, nil
}
// DNS address (non-IP).
ctx, cancel := context.WithCancel(context.Background())
d := &dnsResolver{
host: host,
port: port,
ctx: ctx,
cancel: cancel,
cc: cc,
rn: make(chan struct{}, 1),
disableServiceConfig: opts.DisableServiceConfig,
host: host,
port: port,
ctx: ctx,
cancel: cancel,
cc: cc,
rn: make(chan struct{}, 1),
enableServiceConfig: envconfig.EnableTXTServiceConfig && !opts.DisableServiceConfig,
}
d.resolver, err = internal.NewNetResolver(target.URL.Host)
@@ -181,8 +184,8 @@ type dnsResolver struct {
// finishes, race detector sometimes will warn lookup (READ the lookup
// function pointers) inside watcher() goroutine has data race with
// replaceNetFunc (WRITE the lookup function pointers).
wg sync.WaitGroup
disableServiceConfig bool
wg sync.WaitGroup
enableServiceConfig bool
}
// ResolveNow invoke an immediate resolution of the target that this
@@ -342,11 +345,19 @@ func (d *dnsResolver) lookup() (*resolver.State, error) {
return nil, hostErr
}
state := resolver.State{Addresses: addrs}
eps := make([]resolver.Endpoint, 0, len(addrs))
for _, addr := range addrs {
eps = append(eps, resolver.Endpoint{Addresses: []resolver.Address{addr}})
}
state := resolver.State{
Addresses: addrs,
Endpoints: eps,
}
if len(srv) > 0 {
state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv})
}
if !d.disableServiceConfig {
if d.enableServiceConfig {
state.ServiceConfig = d.lookupTXT(ctx)
}
return &state, nil

View File

@@ -20,6 +20,7 @@ import (
"fmt"
estats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/stats"
)
@@ -28,6 +29,7 @@ import (
// It eats any record calls where the label values provided do not match the
// number of label keys.
type MetricsRecorderList struct {
internal.EnforceMetricsRecorderEmbedding
// metricsRecorders are the metrics recorders this list will forward to.
metricsRecorders []estats.MetricsRecorder
}
@@ -64,6 +66,16 @@ func (l *MetricsRecorderList) RecordInt64Count(handle *estats.Int64CountHandle,
}
}
// RecordInt64UpDownCount records the measurement alongside labels on the int
// count associated with the provided handle.
func (l *MetricsRecorderList) RecordInt64UpDownCount(handle *estats.Int64UpDownCountHandle, incr int64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordInt64UpDownCount(handle, incr, labels...)
}
}
// RecordFloat64Count records the measurement alongside labels on the float
// count associated with the provided handle.
func (l *MetricsRecorderList) RecordFloat64Count(handle *estats.Float64CountHandle, incr float64, labels ...string) {
@@ -103,3 +115,61 @@ func (l *MetricsRecorderList) RecordInt64Gauge(handle *estats.Int64GaugeHandle,
metricRecorder.RecordInt64Gauge(handle, incr, labels...)
}
}
// RegisterAsyncReporter forwards the registration to all underlying metrics
// recorders.
//
// It returns a cleanup function that, when called, invokes the cleanup function
// returned by each underlying recorder, ensuring the reporter is unregistered
// from all of them.
func (l *MetricsRecorderList) RegisterAsyncReporter(reporter estats.AsyncMetricReporter, metrics ...estats.AsyncMetric) func() {
descriptorsMap := make(map[*estats.MetricDescriptor]bool, len(metrics))
for _, m := range metrics {
descriptorsMap[m.Descriptor()] = true
}
unregisterFns := make([]func(), 0, len(l.metricsRecorders))
for _, mr := range l.metricsRecorders {
// Wrap the AsyncMetricsRecorder to intercept calls to RecordInt64Gauge
// and validate the labels.
wrappedCallback := func(recorder estats.AsyncMetricsRecorder) error {
wrappedRecorder := &asyncRecorderWrapper{
delegate: recorder,
descriptors: descriptorsMap,
}
return reporter.Report(wrappedRecorder)
}
unregisterFns = append(unregisterFns, mr.RegisterAsyncReporter(estats.AsyncMetricReporterFunc(wrappedCallback), metrics...))
}
// Wrap the cleanup function using the internal delegate.
// In production, this returns realCleanup as-is.
// In tests, the leak checker can swap this to track the registration lifetime.
return internal.AsyncReporterCleanupDelegate(defaultCleanUp(unregisterFns))
}
func defaultCleanUp(unregisterFns []func()) func() {
return func() {
for _, unregister := range unregisterFns {
unregister()
}
}
}
type asyncRecorderWrapper struct {
delegate estats.AsyncMetricsRecorder
descriptors map[*estats.MetricDescriptor]bool
}
// RecordIntAsync64Gauge records the measurement alongside labels on the int
// gauge associated with the provided handle.
func (w *asyncRecorderWrapper) RecordInt64AsyncGauge(handle *estats.Int64AsyncGaugeHandle, value int64, labels ...string) {
// Ensure only metrics for descriptors passed during callback registration
// are emitted.
d := handle.Descriptor()
if _, ok := w.descriptors[d]; !ok {
return
}
// Validate labels and delegate.
verifyLabels(d, labels...)
w.delegate.RecordInt64AsyncGauge(handle, value, labels...)
}

70
vendor/google.golang.org/grpc/internal/stats/stats.go generated vendored Normal file
View File

@@ -0,0 +1,70 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package stats
import (
"context"
"google.golang.org/grpc/stats"
)
type combinedHandler struct {
handlers []stats.Handler
}
// NewCombinedHandler combines multiple stats.Handlers into a single handler.
//
// It returns nil if no handlers are provided. If only one handler is
// provided, it is returned directly without wrapping.
func NewCombinedHandler(handlers ...stats.Handler) stats.Handler {
switch len(handlers) {
case 0:
return nil
case 1:
return handlers[0]
default:
return &combinedHandler{handlers: handlers}
}
}
func (ch *combinedHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
for _, h := range ch.handlers {
ctx = h.TagRPC(ctx, info)
}
return ctx
}
func (ch *combinedHandler) HandleRPC(ctx context.Context, stats stats.RPCStats) {
for _, h := range ch.handlers {
h.HandleRPC(ctx, stats)
}
}
func (ch *combinedHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
for _, h := range ch.handlers {
ctx = h.TagConn(ctx, info)
}
return ctx
}
func (ch *combinedHandler) HandleConn(ctx context.Context, stats stats.ConnStats) {
for _, h := range ch.handlers {
h.HandleConn(ctx, stats)
}
}

View File

@@ -24,30 +24,34 @@ import (
"golang.org/x/net/http2"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
)
// ClientStream implements streaming functionality for a gRPC client.
type ClientStream struct {
*Stream // Embed for common stream functionality.
Stream // Embed for common stream functionality.
ct *http2Client
done chan struct{} // closed at the end of stream to unblock writers.
doneFunc func() // invoked at the end of stream.
headerChan chan struct{} // closed to indicate the end of header metadata.
headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
headerChan chan struct{} // closed to indicate the end of header metadata.
header metadata.MD // the received header metadata
status *status.Status // the status error received from the server
// Non-pointer fields are at the end to optimize GC allocations.
// headerValid indicates whether a valid header was received. Only
// meaningful after headerChan is closed (always call waitOnHeader() before
// reading its value).
headerValid bool
header metadata.MD // the received header metadata
noHeaders bool // set if the client never received headers (set only after the stream is done).
bytesReceived atomic.Bool // indicates whether any bytes have been received on this stream
unprocessed atomic.Bool // set if the server sends a refused stream or GOAWAY including this stream
status *status.Status // the status error received from the server
headerValid bool
noHeaders bool // set if the client never received headers (set only after the stream is done).
headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
bytesReceived atomic.Bool // indicates whether any bytes have been received on this stream
unprocessed atomic.Bool // set if the server sends a refused stream or GOAWAY including this stream
statsHandler stats.Handler // nil for internal streams (e.g., health check, ORCA) where telemetry is not supported.
}
// Read reads an n byte message from the input stream.
@@ -142,3 +146,11 @@ func (s *ClientStream) TrailersOnly() bool {
func (s *ClientStream) Status() *status.Status {
return s.status
}
func (s *ClientStream) requestRead(n int) {
s.ct.adjustWindow(s, uint32(n))
}
func (s *ClientStream) updateWindow(n int) {
s.ct.updateWindow(s, uint32(n))
}

View File

@@ -24,16 +24,13 @@ import (
"fmt"
"net"
"runtime"
"strconv"
"sync"
"sync/atomic"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/status"
)
var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
@@ -147,11 +144,9 @@ type cleanupStream struct {
func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM
type earlyAbortStream struct {
httpStatus uint32
streamID uint32
contentSubtype string
status *status.Status
rst bool
streamID uint32
rst bool
hf []hpack.HeaderField // Pre-built header fields
}
func (*earlyAbortStream) isTransportResponseFrame() bool { return false }
@@ -496,6 +491,16 @@ const (
serverSide
)
// maxWriteBufSize is the maximum length (number of elements) the cached
// writeBuf can grow to. The length depends on the number of buffers
// contained within the BufferSlice produced by the codec, which is
// generally small.
//
// If a writeBuf larger than this limit is required, it will be allocated
// and freed after use, rather than being cached. This avoids holding
// on to large amounts of memory.
const maxWriteBufSize = 64
// Loopy receives frames from the control buffer.
// Each frame is handled individually; most of the work done by loopy goes
// into handling data frames. Loopy maintains a queue of active streams, and each
@@ -530,6 +535,8 @@ type loopyWriter struct {
// Side-specific handlers
ssGoAwayHandler func(*goAway) (bool, error)
writeBuf [][]byte // cached slice to avoid heap allocations for calls to mem.Reader.Peek.
}
func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger, goAwayHandler func(*goAway) (bool, error), bufferPool mem.BufferPool) *loopyWriter {
@@ -665,11 +672,10 @@ func (l *loopyWriter) incomingSettingsHandler(s *incomingSettings) error {
func (l *loopyWriter) registerStreamHandler(h *registerStream) {
str := &outStream{
id: h.streamID,
state: empty,
itl: &itemList{},
wq: h.wq,
reader: mem.BufferSlice{}.Reader(),
id: h.streamID,
state: empty,
itl: &itemList{},
wq: h.wq,
}
l.estdStreams[h.streamID] = str
}
@@ -701,11 +707,10 @@ func (l *loopyWriter) headerHandler(h *headerFrame) error {
}
// Case 2: Client wants to originate stream.
str := &outStream{
id: h.streamID,
state: empty,
itl: &itemList{},
wq: h.wq,
reader: mem.BufferSlice{}.Reader(),
id: h.streamID,
state: empty,
itl: &itemList{},
wq: h.wq,
}
return l.originateStream(str, h)
}
@@ -833,18 +838,7 @@ func (l *loopyWriter) earlyAbortStreamHandler(eas *earlyAbortStream) error {
if l.side == clientSide {
return errors.New("earlyAbortStream not handled on client")
}
// In case the caller forgets to set the http status, default to 200.
if eas.httpStatus == 0 {
eas.httpStatus = 200
}
headerFields := []hpack.HeaderField{
{Name: ":status", Value: strconv.Itoa(int(eas.httpStatus))},
{Name: "content-type", Value: grpcutil.ContentType(eas.contentSubtype)},
{Name: "grpc-status", Value: strconv.Itoa(int(eas.status.Code()))},
{Name: "grpc-message", Value: encodeGrpcMessage(eas.status.Message())},
}
if err := l.writeHeader(eas.streamID, true, headerFields, nil); err != nil {
if err := l.writeHeader(eas.streamID, true, eas.hf, nil); err != nil {
return err
}
if eas.rst {
@@ -948,11 +942,11 @@ func (l *loopyWriter) processData() (bool, error) {
if str == nil {
return true, nil
}
reader := str.reader
reader := &str.reader
dataItem := str.itl.peek().(*dataFrame) // Peek at the first data item this stream.
if !dataItem.processing {
dataItem.processing = true
str.reader.Reset(dataItem.data)
reader.Reset(dataItem.data)
dataItem.data.Free()
}
// A data item is represented by a dataFrame, since it later translates into
@@ -964,11 +958,11 @@ func (l *loopyWriter) processData() (bool, error) {
if len(dataItem.h) == 0 && reader.Remaining() == 0 { // Empty data frame
// Client sends out empty data frame with endStream = true
if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil {
if err := l.framer.writeData(dataItem.streamID, dataItem.endStream, nil); err != nil {
return false, err
}
str.itl.dequeue() // remove the empty data item from stream
_ = reader.Close()
reader.Close()
if str.itl.isEmpty() {
str.state = empty
} else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers.
@@ -1001,25 +995,20 @@ func (l *loopyWriter) processData() (bool, error) {
remainingBytes := len(dataItem.h) + reader.Remaining() - hSize - dSize
size := hSize + dSize
var buf *[]byte
if hSize != 0 && dSize == 0 {
buf = &dataItem.h
} else {
// Note: this is only necessary because the http2.Framer does not support
// partially writing a frame, so the sequence must be materialized into a buffer.
// TODO: Revisit once https://github.com/golang/go/issues/66655 is addressed.
pool := l.bufferPool
if pool == nil {
// Note that this is only supposed to be nil in tests. Otherwise, stream is
// always initialized with a BufferPool.
pool = mem.DefaultBufferPool()
l.writeBuf = l.writeBuf[:0]
if hSize > 0 {
l.writeBuf = append(l.writeBuf, dataItem.h[:hSize])
}
if dSize > 0 {
var err error
l.writeBuf, err = reader.Peek(dSize, l.writeBuf)
if err != nil {
// This must never happen since the reader must have at least dSize
// bytes.
// Log an error to fail tests.
l.logger.Errorf("unexpected error while reading Data frame payload: %v", err)
return false, err
}
buf = pool.Get(size)
defer pool.Put(buf)
copy((*buf)[:hSize], dataItem.h)
_, _ = reader.Read((*buf)[hSize:])
}
// Now that outgoing flow controls are checked we can replenish str's write quota
@@ -1032,7 +1021,14 @@ func (l *loopyWriter) processData() (bool, error) {
if dataItem.onEachWrite != nil {
dataItem.onEachWrite()
}
if err := l.framer.fr.WriteData(dataItem.streamID, endStream, (*buf)[:size]); err != nil {
err := l.framer.writeData(dataItem.streamID, endStream, l.writeBuf)
reader.Discard(dSize)
if cap(l.writeBuf) > maxWriteBufSize {
l.writeBuf = nil
} else {
clear(l.writeBuf)
}
if err != nil {
return false, err
}
str.bytesOutStanding += size
@@ -1040,7 +1036,7 @@ func (l *loopyWriter) processData() (bool, error) {
dataItem.h = dataItem.h[hSize:]
if remainingBytes == 0 { // All the data from that message was written out.
_ = reader.Close()
reader.Close()
str.itl.dequeue()
}
if str.itl.isEmpty() {

View File

@@ -28,7 +28,7 @@ import (
// writeQuota is a soft limit on the amount of data a stream can
// schedule before some of it is written out.
type writeQuota struct {
quota int32
_ noCopy
// get waits on read from when quota goes less than or equal to zero.
// replenish writes on it when quota goes positive again.
ch chan struct{}
@@ -38,16 +38,17 @@ type writeQuota struct {
// It is implemented as a field so that it can be updated
// by tests.
replenish func(n int)
quota int32
}
func newWriteQuota(sz int32, done <-chan struct{}) *writeQuota {
w := &writeQuota{
quota: sz,
ch: make(chan struct{}, 1),
done: done,
}
// init allows a writeQuota to be initialized in-place, which is useful for
// resetting a buffer or for avoiding a heap allocation when the buffer is
// embedded in another struct.
func (w *writeQuota) init(sz int32, done <-chan struct{}) {
w.quota = sz
w.ch = make(chan struct{}, 1)
w.done = done
w.replenish = w.realReplenish
return w
}
func (w *writeQuota) get(sz int32) error {
@@ -67,9 +68,9 @@ func (w *writeQuota) get(sz int32) error {
func (w *writeQuota) realReplenish(n int) {
sz := int32(n)
a := atomic.AddInt32(&w.quota, sz)
b := a - sz
if b <= 0 && a > 0 {
newQuota := atomic.AddInt32(&w.quota, sz)
previousQuota := newQuota - sz
if previousQuota <= 0 && newQuota > 0 {
select {
case w.ch <- struct{}{}:
default:

View File

@@ -50,7 +50,7 @@ import (
// NewServerHandlerTransport returns a ServerTransport handling gRPC from
// inside an http.Handler, or writes an HTTP error to w and returns an error.
// It requires that the http Server supports HTTP/2.
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler, bufferPool mem.BufferPool) (ServerTransport, error) {
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler, bufferPool mem.BufferPool) (ServerTransport, error) {
if r.Method != http.MethodPost {
w.Header().Set("Allow", http.MethodPost)
msg := fmt.Sprintf("invalid gRPC request method %q", r.Method)
@@ -170,7 +170,7 @@ type serverHandlerTransport struct {
// TODO make sure this is consistent across handler_server and http2_server
contentSubtype string
stats []stats.Handler
stats stats.Handler
logger *grpclog.PrefixLogger
bufferPool mem.BufferPool
@@ -274,14 +274,14 @@ func (ht *serverHandlerTransport) writeStatus(s *ServerStream, st *status.Status
}
})
if err == nil { // transport has not been closed
if err == nil && ht.stats != nil { // transport has not been closed
// Note: The trailer fields are compressed with hpack after this call returns.
// No WireLength field is set here.
for _, sh := range ht.stats {
sh.HandleRPC(s.Context(), &stats.OutTrailer{
Trailer: s.trailer.Copy(),
})
}
s.hdrMu.Lock()
ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{
Trailer: s.trailer.Copy(),
})
s.hdrMu.Unlock()
}
ht.Close(errors.New("finished writing status"))
return err
@@ -372,19 +372,23 @@ func (ht *serverHandlerTransport) writeHeader(s *ServerStream, md metadata.MD) e
ht.rw.(http.Flusher).Flush()
})
if err == nil {
for _, sh := range ht.stats {
// Note: The header fields are compressed with hpack after this call returns.
// No WireLength field is set here.
sh.HandleRPC(s.Context(), &stats.OutHeader{
Header: md.Copy(),
Compression: s.sendCompress,
})
}
if err == nil && ht.stats != nil {
// Note: The header fields are compressed with hpack after this call returns.
// No WireLength field is set here.
ht.stats.HandleRPC(s.Context(), &stats.OutHeader{
Header: md.Copy(),
Compression: s.sendCompress,
})
}
return err
}
func (ht *serverHandlerTransport) adjustWindow(*ServerStream, uint32) {
}
func (ht *serverHandlerTransport) updateWindow(*ServerStream, uint32) {
}
func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*ServerStream)) {
// With this transport type there will be exactly 1 stream: this HTTP request.
var cancel context.CancelFunc
@@ -409,11 +413,9 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
req := ht.req
s := &ServerStream{
Stream: &Stream{
Stream: Stream{
id: 0, // irrelevant
ctx: ctx,
requestRead: func(int) {},
buf: newRecvBuffer(),
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
@@ -422,9 +424,11 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream
st: ht,
headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
}
s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
windowHandler: func(int) {},
s.Stream.buf.init()
s.readRequester = s
s.trReader = transportReader{
reader: recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: &s.buf},
windowHandler: s,
}
// readerDone is closed when the Body.Read-ing goroutine exits.

View File

@@ -44,6 +44,7 @@ import (
"google.golang.org/grpc/internal/grpcutil"
imetadata "google.golang.org/grpc/internal/metadata"
"google.golang.org/grpc/internal/proxyattributes"
istats "google.golang.org/grpc/internal/stats"
istatus "google.golang.org/grpc/internal/status"
isyscall "google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/internal/transport/networktype"
@@ -105,7 +106,7 @@ type http2Client struct {
kp keepalive.ClientParameters
keepaliveEnabled bool
statsHandlers []stats.Handler
statsHandler stats.Handler
initialWindowSize int32
@@ -335,14 +336,14 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
writerDone: make(chan struct{}),
goAway: make(chan struct{}),
keepaliveDone: make(chan struct{}),
framer: newFramer(conn, writeBufSize, readBufSize, opts.SharedWriteBuffer, maxHeaderListSize),
framer: newFramer(conn, writeBufSize, readBufSize, opts.SharedWriteBuffer, maxHeaderListSize, opts.BufferPool),
fc: &trInFlow{limit: uint32(icwz)},
scheme: scheme,
activeStreams: make(map[uint32]*ClientStream),
isSecure: isSecure,
perRPCCreds: perRPCCreds,
kp: kp,
statsHandlers: opts.StatsHandlers,
statsHandler: istats.NewCombinedHandler(opts.StatsHandlers...),
initialWindowSize: initialWindowSize,
nextID: 1,
maxConcurrentStreams: defaultMaxStreamsClient,
@@ -369,7 +370,7 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
})
t.logger = prefixLoggerForClientTransport(t)
// Add peer information to the http2client context.
t.ctx = peer.NewContext(t.ctx, t.getPeer())
t.ctx = peer.NewContext(t.ctx, t.Peer())
if md, ok := addr.Metadata.(*metadata.MD); ok {
t.md = *md
@@ -386,15 +387,14 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
updateFlowControl: t.updateFlowControl,
}
}
for _, sh := range t.statsHandlers {
t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{
if t.statsHandler != nil {
t.ctx = t.statsHandler.TagConn(t.ctx, &stats.ConnTagInfo{
RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr,
})
connBegin := &stats.ConnBegin{
t.statsHandler.HandleConn(t.ctx, &stats.ConnBegin{
Client: true,
}
sh.HandleConn(t.ctx, connBegin)
})
}
if t.keepaliveEnabled {
t.kpDormancyCond = sync.NewCond(&t.mu)
@@ -478,45 +478,40 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
return t, nil
}
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *ClientStream {
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr, handler stats.Handler) *ClientStream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &ClientStream{
Stream: &Stream{
Stream: Stream{
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(),
contentSubtype: callHdr.ContentSubtype,
},
ct: t,
done: make(chan struct{}),
headerChan: make(chan struct{}),
doneFunc: callHdr.DoneFunc,
}
s.wq = newWriteQuota(defaultWriteQuota, s.done)
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
ct: t,
done: make(chan struct{}),
headerChan: make(chan struct{}),
doneFunc: callHdr.DoneFunc,
statsHandler: handler,
}
s.Stream.buf.init()
s.Stream.wq.init(defaultWriteQuota, s.done)
s.readRequester = s
// The client side stream context should have exactly the same life cycle with the user provided context.
// That means, s.ctx should be read-only. And s.ctx is done iff ctx is done.
// So we use the original context here instead of creating a copy.
s.ctx = ctx
s.trReader = &transportReader{
reader: &recvBufferReader{
ctx: s.ctx,
ctxDone: s.ctx.Done(),
recv: s.buf,
closeStream: func(err error) {
s.Close(err)
},
},
windowHandler: func(n int) {
t.updateWindow(s, uint32(n))
s.trReader = transportReader{
reader: recvBufferReader{
ctx: s.ctx,
ctxDone: s.ctx.Done(),
recv: &s.buf,
clientStream: s,
},
windowHandler: s,
}
return s
}
func (t *http2Client) getPeer() *peer.Peer {
func (t *http2Client) Peer() *peer.Peer {
return &peer.Peer{
Addr: t.remoteAddr,
AuthInfo: t.authInfo, // Can be nil
@@ -556,6 +551,22 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
// Make the slice of certain predictable size to reduce allocations made by append.
hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te
hfLen += len(authData) + len(callAuthData)
registeredCompressors := t.registeredCompressors
if callHdr.AcceptedCompressors != nil {
registeredCompressors = *callHdr.AcceptedCompressors
}
if callHdr.PreviousAttempts > 0 {
hfLen++
}
if callHdr.SendCompress != "" {
hfLen++
}
if registeredCompressors != "" {
hfLen++
}
if _, ok := ctx.Deadline(); ok {
hfLen++
}
headerFields := make([]hpack.HeaderField, 0, hfLen)
headerFields = append(headerFields, hpack.HeaderField{Name: ":method", Value: "POST"})
headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme})
@@ -568,7 +579,6 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-previous-rpc-attempts", Value: strconv.Itoa(callHdr.PreviousAttempts)})
}
registeredCompressors := t.registeredCompressors
if callHdr.SendCompress != "" {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
// Include the outgoing compressor name when compressor is not registered
@@ -735,8 +745,8 @@ func (e NewStreamError) Error() string {
// NewStream creates a stream and registers it into the transport as "active"
// streams. All non-nil errors returned will be *NewStreamError.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientStream, error) {
ctx = peer.NewContext(ctx, t.getPeer())
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr, handler stats.Handler) (*ClientStream, error) {
ctx = peer.NewContext(ctx, t.Peer())
// ServerName field of the resolver returned address takes precedence over
// Host field of CallHdr to determine the :authority header. This is because,
@@ -772,7 +782,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS
if err != nil {
return nil, &NewStreamError{Err: err, AllowTransparentRetry: false}
}
s := t.newStream(ctx, callHdr)
s := t.newStream(ctx, callHdr, handler)
cleanup := func(err error) {
if s.swapState(streamDone) == streamDone {
// If it was already done, return.
@@ -811,7 +821,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS
return nil
},
onOrphaned: cleanup,
wq: s.wq,
wq: &s.wq,
}
firstTry := true
var ch chan struct{}
@@ -842,7 +852,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS
transportDrainRequired = t.nextID > MaxStreamID
s.id = hdr.streamID
s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
s.fc = inFlow{limit: uint32(t.initialWindowSize)}
t.activeStreams[s.id] = s
t.mu.Unlock()
@@ -893,27 +903,23 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS
return nil, &NewStreamError{Err: ErrConnClosing, AllowTransparentRetry: true}
}
}
if len(t.statsHandlers) != 0 {
if s.statsHandler != nil {
header, ok := metadata.FromOutgoingContext(ctx)
if ok {
header.Set("user-agent", t.userAgent)
} else {
header = metadata.Pairs("user-agent", t.userAgent)
}
for _, sh := range t.statsHandlers {
// Note: The header fields are compressed with hpack after this call returns.
// No WireLength field is set here.
// Note: Creating a new stats object to prevent pollution.
outHeader := &stats.OutHeader{
Client: true,
FullMethod: callHdr.Method,
RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr,
Compression: callHdr.SendCompress,
Header: header,
}
sh.HandleRPC(s.ctx, outHeader)
}
// Note: The header fields are compressed with hpack after this call returns.
// No WireLength field is set here.
s.statsHandler.HandleRPC(s.ctx, &stats.OutHeader{
Client: true,
FullMethod: callHdr.Method,
RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr,
Compression: callHdr.SendCompress,
Header: header,
})
}
if transportDrainRequired {
if t.logger.V(logLevel) {
@@ -990,6 +996,9 @@ func (t *http2Client) closeStream(s *ClientStream, err error, rst bool, rstCode
// accessed anymore.
func (t *http2Client) Close(err error) {
t.conn.SetWriteDeadline(time.Now().Add(time.Second * 10))
// For background on the deadline value chosen here, see
// https://github.com/grpc/grpc-go/issues/8425#issuecomment-3057938248 .
t.conn.SetReadDeadline(time.Now().Add(time.Second))
t.mu.Lock()
// Make sure we only close once.
if t.state == closing {
@@ -1051,11 +1060,10 @@ func (t *http2Client) Close(err error) {
for _, s := range streams {
t.closeStream(s, err, false, http2.ErrCodeNo, st, nil, false)
}
for _, sh := range t.statsHandlers {
connEnd := &stats.ConnEnd{
if t.statsHandler != nil {
t.statsHandler.HandleConn(t.ctx, &stats.ConnEnd{
Client: true,
}
sh.HandleConn(t.ctx, connEnd)
})
}
}
@@ -1166,7 +1174,7 @@ func (t *http2Client) updateFlowControl(n uint32) {
})
}
func (t *http2Client) handleData(f *http2.DataFrame) {
func (t *http2Client) handleData(f *parsedDataFrame) {
size := f.Header().Length
var sendBDPPing bool
if t.bdpEst != nil {
@@ -1210,22 +1218,15 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false)
return
}
dataLen := f.data.Len()
if f.Header().Flags.Has(http2.FlagDataPadded) {
if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 {
if w := s.fc.onRead(size - uint32(dataLen)); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{s.id, w})
}
}
// TODO(bradfitz, zhaoq): A copy is required here because there is no
// guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated?
if len(f.Data()) > 0 {
pool := t.bufferPool
if pool == nil {
// Note that this is only supposed to be nil in tests. Otherwise, stream is
// always initialized with a BufferPool.
pool = mem.DefaultBufferPool()
}
s.write(recvMsg{buffer: mem.Copy(f.Data(), pool)})
if dataLen > 0 {
f.data.Ref()
s.write(recvMsg{buffer: f.data})
}
}
// The server has closed the stream without sending trailers. Record that
@@ -1465,17 +1466,14 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
contentTypeErr = "malformed header: missing HTTP content-type"
grpcMessage string
recvCompress string
httpStatusCode *int
httpStatusErr string
rawStatusCode = codes.Unknown
// the code from the grpc-status header, if present
grpcStatusCode = codes.Unknown
// headerError is set if an error is encountered while parsing the headers
headerError string
httpStatus string
)
if initialHeader {
httpStatusErr = "malformed header: missing HTTP status"
}
for _, hf := range frame.Fields {
switch hf.Name {
case "content-type":
@@ -1491,35 +1489,15 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
case "grpc-status":
code, err := strconv.ParseInt(hf.Value, 10, 32)
if err != nil {
se := status.New(codes.Internal, fmt.Sprintf("transport: malformed grpc-status: %v", err))
se := status.New(codes.Unknown, fmt.Sprintf("transport: malformed grpc-status: %v", err))
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
return
}
rawStatusCode = codes.Code(uint32(code))
grpcStatusCode = codes.Code(uint32(code))
case "grpc-message":
grpcMessage = decodeGrpcMessage(hf.Value)
case ":status":
if hf.Value == "200" {
httpStatusErr = ""
statusCode := 200
httpStatusCode = &statusCode
break
}
c, err := strconv.ParseInt(hf.Value, 10, 32)
if err != nil {
se := status.New(codes.Internal, fmt.Sprintf("transport: malformed http-status: %v", err))
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
return
}
statusCode := int(c)
httpStatusCode = &statusCode
httpStatusErr = fmt.Sprintf(
"unexpected HTTP status code received from server: %d (%s)",
statusCode,
http.StatusText(statusCode),
)
httpStatus = hf.Value
default:
if isReservedHeader(hf.Name) && !isWhitelistedHeader(hf.Name) {
break
@@ -1534,25 +1512,52 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
}
}
if !isGRPC || httpStatusErr != "" {
var code = codes.Internal // when header does not include HTTP status, return INTERNAL
if httpStatusCode != nil {
// If a non-gRPC response is received, then evaluate the HTTP status to
// process the response and close the stream.
// In case http status doesn't provide any error information (status : 200),
// then evalute response code to be Unknown.
if !isGRPC {
var grpcErrorCode = codes.Internal
if httpStatus == "" {
httpStatusErr = "malformed header: missing HTTP status"
} else {
// Parse the status codes (e.g. "200", 404").
statusCode, err := strconv.Atoi(httpStatus)
if err != nil {
se := status.New(grpcErrorCode, fmt.Sprintf("transport: malformed http-status: %v", err))
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
return
}
if statusCode >= 100 && statusCode < 200 {
if endStream {
se := status.New(codes.Internal, fmt.Sprintf(
"protocol error: informational header with status code %d must not have END_STREAM set", statusCode))
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
}
// In case of informational headers, return.
return
}
httpStatusErr = fmt.Sprintf(
"unexpected HTTP status code received from server: %d (%s)",
statusCode,
http.StatusText(statusCode),
)
var ok bool
code, ok = HTTPStatusConvTab[*httpStatusCode]
grpcErrorCode, ok = HTTPStatusConvTab[statusCode]
if !ok {
code = codes.Unknown
grpcErrorCode = codes.Unknown
}
}
var errs []string
if httpStatusErr != "" {
errs = append(errs, httpStatusErr)
}
if contentTypeErr != "" {
errs = append(errs, contentTypeErr)
}
// Verify the HTTP response is a 200.
se := status.New(code, strings.Join(errs, "; "))
se := status.New(grpcErrorCode, strings.Join(errs, "; "))
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
return
}
@@ -1583,22 +1588,20 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
}
}
for _, sh := range t.statsHandlers {
if s.statsHandler != nil {
if !endStream {
inHeader := &stats.InHeader{
s.statsHandler.HandleRPC(s.ctx, &stats.InHeader{
Client: true,
WireLength: int(frame.Header().Length),
Header: metadata.MD(mdata).Copy(),
Compression: s.recvCompress,
}
sh.HandleRPC(s.ctx, inHeader)
})
} else {
inTrailer := &stats.InTrailer{
s.statsHandler.HandleRPC(s.ctx, &stats.InTrailer{
Client: true,
WireLength: int(frame.Header().Length),
Trailer: metadata.MD(mdata).Copy(),
}
sh.HandleRPC(s.ctx, inTrailer)
})
}
}
@@ -1606,7 +1609,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
return
}
status := istatus.NewWithProto(rawStatusCode, grpcMessage, mdata[grpcStatusDetailsBinHeader])
status := istatus.NewWithProto(grpcStatusCode, grpcMessage, mdata[grpcStatusDetailsBinHeader])
// If client received END_STREAM from server while stream was still active,
// send RST_STREAM.
@@ -1653,7 +1656,7 @@ func (t *http2Client) reader(errCh chan<- error) {
// loop to keep reading incoming messages on this transport.
for {
t.controlBuf.throttle()
frame, err := t.framer.fr.ReadFrame()
frame, err := t.framer.readFrame()
if t.keepaliveEnabled {
atomic.StoreInt64(&t.lastRead, time.Now().UnixNano())
}
@@ -1668,7 +1671,7 @@ func (t *http2Client) reader(errCh chan<- error) {
if s != nil {
// use error detail to provide better err message
code := http2ErrConvTab[se.Code]
errorDetail := t.framer.fr.ErrorDetail()
errorDetail := t.framer.errorDetail()
var msg string
if errorDetail != nil {
msg = errorDetail.Error()
@@ -1686,8 +1689,9 @@ func (t *http2Client) reader(errCh chan<- error) {
switch frame := frame.(type) {
case *http2.MetaHeadersFrame:
t.operateHeaders(frame)
case *http2.DataFrame:
case *parsedDataFrame:
t.handleData(frame)
frame.data.Free()
case *http2.RSTStreamFrame:
t.handleRSTStream(frame)
case *http2.SettingsFrame:
@@ -1807,8 +1811,6 @@ func (t *http2Client) socketMetrics() *channelz.EphemeralSocketMetrics {
}
}
func (t *http2Client) RemoteAddr() net.Addr { return t.remoteAddr }
func (t *http2Client) incrMsgSent() {
if channelz.IsOn() {
t.channelz.SocketMetrics.MessagesSent.Add(1)

View File

@@ -35,6 +35,8 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/protobuf/proto"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcutil"
@@ -42,7 +44,6 @@ import (
istatus "google.golang.org/grpc/internal/status"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/mem"
"google.golang.org/protobuf/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
@@ -86,7 +87,7 @@ type http2Server struct {
// updates, reset streams, and various settings) to the controller.
controlBuf *controlBuffer
fc *trInFlow
stats []stats.Handler
stats stats.Handler
// Keepalive and max-age parameters for the server.
kp keepalive.ServerParameters
// Keepalive enforcement policy.
@@ -168,7 +169,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
if config.MaxHeaderListSize != nil {
maxHeaderListSize = *config.MaxHeaderListSize
}
framer := newFramer(conn, writeBufSize, readBufSize, config.SharedWriteBuffer, maxHeaderListSize)
framer := newFramer(conn, writeBufSize, readBufSize, config.SharedWriteBuffer, maxHeaderListSize, config.BufferPool)
// Send initial settings as connection preface to client.
isettings := []http2.Setting{{
ID: http2.SettingMaxFrameSize,
@@ -260,7 +261,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
fc: &trInFlow{limit: uint32(icwz)},
state: reachable,
activeStreams: make(map[uint32]*ServerStream),
stats: config.StatsHandlers,
stats: config.StatsHandler,
kp: kp,
idle: time.Now(),
kep: kep,
@@ -390,16 +391,15 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
}
t.maxStreamID = streamID
buf := newRecvBuffer()
s := &ServerStream{
Stream: &Stream{
id: streamID,
buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)},
Stream: Stream{
id: streamID,
fc: inFlow{limit: uint32(t.initialWindowSize)},
},
st: t,
headerWireLength: int(frame.Header().Length),
}
s.Stream.buf.init()
var (
// if false, content-type was missing or invalid
isGRPC = false
@@ -479,13 +479,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
if t.logger.V(logLevel) {
t.logger.Infof("Aborting the stream early: %v", errMsg)
}
t.controlBuf.put(&earlyAbortStream{
httpStatus: http.StatusBadRequest,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: status.New(codes.Internal, errMsg),
rst: !frame.StreamEnded(),
})
t.writeEarlyAbort(streamID, s.contentSubtype, status.New(codes.Internal, errMsg), http.StatusBadRequest, !frame.StreamEnded())
return nil
}
@@ -499,23 +493,11 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
return nil
}
if !isGRPC {
t.controlBuf.put(&earlyAbortStream{
httpStatus: http.StatusUnsupportedMediaType,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: status.Newf(codes.InvalidArgument, "invalid gRPC request content-type %q", contentType),
rst: !frame.StreamEnded(),
})
t.writeEarlyAbort(streamID, s.contentSubtype, status.Newf(codes.InvalidArgument, "invalid gRPC request content-type %q", contentType), http.StatusUnsupportedMediaType, !frame.StreamEnded())
return nil
}
if headerError != nil {
t.controlBuf.put(&earlyAbortStream{
httpStatus: http.StatusBadRequest,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: headerError,
rst: !frame.StreamEnded(),
})
t.writeEarlyAbort(streamID, s.contentSubtype, headerError, http.StatusBadRequest, !frame.StreamEnded())
return nil
}
@@ -569,13 +551,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
if t.logger.V(logLevel) {
t.logger.Infof("Aborting the stream early: %v", errMsg)
}
t.controlBuf.put(&earlyAbortStream{
httpStatus: http.StatusMethodNotAllowed,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: status.New(codes.Internal, errMsg),
rst: !frame.StreamEnded(),
})
t.writeEarlyAbort(streamID, s.contentSubtype, status.New(codes.Internal, errMsg), http.StatusMethodNotAllowed, !frame.StreamEnded())
s.cancel()
return nil
}
@@ -590,27 +566,16 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
if !ok {
stat = status.New(codes.PermissionDenied, err.Error())
}
t.controlBuf.put(&earlyAbortStream{
httpStatus: http.StatusOK,
streamID: s.id,
contentSubtype: s.contentSubtype,
status: stat,
rst: !frame.StreamEnded(),
})
t.writeEarlyAbort(s.id, s.contentSubtype, stat, http.StatusOK, !frame.StreamEnded())
return nil
}
}
if s.ctx.Err() != nil {
t.mu.Unlock()
st := status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
// Early abort in case the timeout was zero or so low it already fired.
t.controlBuf.put(&earlyAbortStream{
httpStatus: http.StatusOK,
streamID: s.id,
contentSubtype: s.contentSubtype,
status: status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()),
rst: !frame.StreamEnded(),
})
t.writeEarlyAbort(s.id, s.contentSubtype, st, http.StatusOK, !frame.StreamEnded())
return nil
}
@@ -640,25 +605,21 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
t.channelz.SocketMetrics.StreamsStarted.Add(1)
t.channelz.SocketMetrics.LastRemoteStreamCreatedTimestamp.Store(time.Now().UnixNano())
}
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
}
s.readRequester = s
s.ctxDone = s.ctx.Done()
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
s.trReader = &transportReader{
reader: &recvBufferReader{
s.Stream.wq.init(defaultWriteQuota, s.ctxDone)
s.trReader = transportReader{
reader: recvBufferReader{
ctx: s.ctx,
ctxDone: s.ctxDone,
recv: s.buf,
},
windowHandler: func(n int) {
t.updateWindow(s, uint32(n))
recv: &s.buf,
},
windowHandler: s,
}
// Register the stream with loopy.
t.controlBuf.put(&registerStream{
streamID: s.id,
wq: s.wq,
wq: &s.wq,
})
handle(s)
return nil
@@ -674,7 +635,7 @@ func (t *http2Server) HandleStreams(ctx context.Context, handle func(*ServerStre
}()
for {
t.controlBuf.throttle()
frame, err := t.framer.fr.ReadFrame()
frame, err := t.framer.readFrame()
atomic.StoreInt64(&t.lastRead, time.Now().UnixNano())
if err != nil {
if se, ok := err.(http2.StreamError); ok {
@@ -711,8 +672,9 @@ func (t *http2Server) HandleStreams(ctx context.Context, handle func(*ServerStre
})
continue
}
case *http2.DataFrame:
case *parsedDataFrame:
t.handleData(frame)
frame.data.Free()
case *http2.RSTStreamFrame:
t.handleRSTStream(frame)
case *http2.SettingsFrame:
@@ -792,7 +754,7 @@ func (t *http2Server) updateFlowControl(n uint32) {
}
func (t *http2Server) handleData(f *http2.DataFrame) {
func (t *http2Server) handleData(f *parsedDataFrame) {
size := f.Header().Length
var sendBDPPing bool
if t.bdpEst != nil {
@@ -837,22 +799,15 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
t.closeStream(s, true, http2.ErrCodeFlowControl, false)
return
}
dataLen := f.data.Len()
if f.Header().Flags.Has(http2.FlagDataPadded) {
if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 {
if w := s.fc.onRead(size - uint32(dataLen)); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{s.id, w})
}
}
// TODO(bradfitz, zhaoq): A copy is required here because there is no
// guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated?
if len(f.Data()) > 0 {
pool := t.bufferPool
if pool == nil {
// Note that this is only supposed to be nil in tests. Otherwise, stream is
// always initialized with a BufferPool.
pool = mem.DefaultBufferPool()
}
s.write(recvMsg{buffer: mem.Copy(f.Data(), pool)})
if dataLen > 0 {
f.data.Ref()
s.write(recvMsg{buffer: f.data})
}
}
if f.StreamEnded() {
@@ -979,13 +934,12 @@ func appendHeaderFieldsFromMD(headerFields []hpack.HeaderField, md metadata.MD)
return headerFields
}
func (t *http2Server) checkForHeaderListSize(it any) bool {
func (t *http2Server) checkForHeaderListSize(hf []hpack.HeaderField) bool {
if t.maxSendHeaderListSize == nil {
return true
}
hdrFrame := it.(*headerFrame)
var sz int64
for _, f := range hdrFrame.hf {
for _, f := range hf {
if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) {
if t.logger.V(logLevel) {
t.logger.Infof("Header list size to send violates the maximum size (%d bytes) set by client", *t.maxSendHeaderListSize)
@@ -996,6 +950,42 @@ func (t *http2Server) checkForHeaderListSize(it any) bool {
return true
}
// writeEarlyAbort sends an early abort response with the given HTTP status and
// gRPC status. If the header list size exceeds the peer's limit, it sends a
// RST_STREAM instead.
func (t *http2Server) writeEarlyAbort(streamID uint32, contentSubtype string, stat *status.Status, httpStatus uint32, rst bool) {
hf := []hpack.HeaderField{
{Name: ":status", Value: strconv.Itoa(int(httpStatus))},
{Name: "content-type", Value: grpcutil.ContentType(contentSubtype)},
{Name: "grpc-status", Value: strconv.Itoa(int(stat.Code()))},
{Name: "grpc-message", Value: encodeGrpcMessage(stat.Message())},
}
if p := istatus.RawStatusProto(stat); len(p.GetDetails()) > 0 {
stBytes, err := proto.Marshal(p)
if err != nil {
t.logger.Errorf("Failed to marshal rpc status: %s, error: %v", pretty.ToJSON(p), err)
}
if err == nil {
hf = append(hf, hpack.HeaderField{Name: grpcStatusDetailsBinHeader, Value: encodeBinHeader(stBytes)})
}
}
success, _ := t.controlBuf.executeAndPut(func() bool {
return t.checkForHeaderListSize(hf)
}, &earlyAbortStream{
streamID: streamID,
rst: rst,
hf: hf,
})
if !success {
t.controlBuf.put(&cleanupStream{
streamID: streamID,
rst: true,
rstCode: http2.ErrCodeInternal,
onWrite: func() {},
})
}
}
func (t *http2Server) streamContextErr(s *ServerStream) error {
select {
case <-t.done:
@@ -1051,7 +1041,7 @@ func (t *http2Server) writeHeaderLocked(s *ServerStream) error {
endStream: false,
onWrite: t.setResetPingStrikes,
}
success, err := t.controlBuf.executeAndPut(func() bool { return t.checkForHeaderListSize(hf) }, hf)
success, err := t.controlBuf.executeAndPut(func() bool { return t.checkForHeaderListSize(hf.hf) }, hf)
if !success {
if err != nil {
return err
@@ -1059,14 +1049,13 @@ func (t *http2Server) writeHeaderLocked(s *ServerStream) error {
t.closeStream(s, true, http2.ErrCodeInternal, false)
return ErrHeaderListSizeLimitViolation
}
for _, sh := range t.stats {
if t.stats != nil {
// Note: Headers are compressed with hpack after this call returns.
// No WireLength field is set here.
outHeader := &stats.OutHeader{
t.stats.HandleRPC(s.Context(), &stats.OutHeader{
Header: s.header.Copy(),
Compression: s.sendCompress,
}
sh.HandleRPC(s.Context(), outHeader)
})
}
return nil
}
@@ -1122,7 +1111,7 @@ func (t *http2Server) writeStatus(s *ServerStream, st *status.Status) error {
}
success, err := t.controlBuf.executeAndPut(func() bool {
return t.checkForHeaderListSize(trailingHeader)
return t.checkForHeaderListSize(trailingHeader.hf)
}, nil)
if !success {
if err != nil {
@@ -1134,10 +1123,10 @@ func (t *http2Server) writeStatus(s *ServerStream, st *status.Status) error {
// Send a RST_STREAM after the trailers if the client has not already half-closed.
rst := s.getState() == streamActive
t.finishStream(s, rst, http2.ErrCodeNo, trailingHeader, true)
for _, sh := range t.stats {
if t.stats != nil {
// Note: The trailer fields are compressed with hpack after this call returns.
// No WireLength field is set here.
sh.HandleRPC(s.Context(), &stats.OutTrailer{
t.stats.HandleRPC(s.Context(), &stats.OutTrailer{
Trailer: s.trailer.Copy(),
})
}
@@ -1305,7 +1294,8 @@ func (t *http2Server) Close(err error) {
// deleteStream deletes the stream s from transport's active streams.
func (t *http2Server) deleteStream(s *ServerStream, eosReceived bool) {
t.mu.Lock()
if _, ok := t.activeStreams[s.id]; ok {
_, isActive := t.activeStreams[s.id]
if isActive {
delete(t.activeStreams, s.id)
if len(t.activeStreams) == 0 {
t.idle = time.Now()
@@ -1313,7 +1303,7 @@ func (t *http2Server) deleteStream(s *ServerStream, eosReceived bool) {
}
t.mu.Unlock()
if channelz.IsOn() {
if isActive && channelz.IsOn() {
if eosReceived {
t.channelz.SocketMetrics.StreamsSucceeded.Add(1)
} else {
@@ -1353,10 +1343,10 @@ func (t *http2Server) closeStream(s *ServerStream, rst bool, rstCode http2.ErrCo
// called to interrupt the potential blocking on other goroutines.
s.cancel()
oldState := s.swapState(streamDone)
if oldState == streamDone {
return
}
// We can't return early even if the stream's state is "done" as the state
// might have been set by the `finishStream` method. Deleting the stream via
// `finishStream` can get blocked on flow control.
s.swapState(streamDone)
t.deleteStream(s, eosReceived)
t.controlBuf.put(&cleanupStream{

View File

@@ -25,7 +25,6 @@ import (
"fmt"
"io"
"math"
"net"
"net/http"
"net/url"
"strconv"
@@ -37,6 +36,7 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/mem"
)
const (
@@ -300,11 +300,11 @@ type bufWriter struct {
buf []byte
offset int
batchSize int
conn net.Conn
conn io.Writer
err error
}
func newBufWriter(conn net.Conn, batchSize int, pool *sync.Pool) *bufWriter {
func newBufWriter(conn io.Writer, batchSize int, pool *sync.Pool) *bufWriter {
w := &bufWriter{
batchSize: batchSize,
conn: conn,
@@ -388,15 +388,29 @@ func toIOError(err error) error {
return ioError{error: err}
}
type parsedDataFrame struct {
http2.FrameHeader
data mem.Buffer
}
func (df *parsedDataFrame) StreamEnded() bool {
return df.FrameHeader.Flags.Has(http2.FlagDataEndStream)
}
type framer struct {
writer *bufWriter
fr *http2.Framer
writer *bufWriter
fr *http2.Framer
headerBuf []byte // cached slice for framer headers to reduce heap allocs.
reader io.Reader
dataFrame parsedDataFrame // Cached data frame to avoid heap allocations.
pool mem.BufferPool
errDetail error
}
var writeBufferPoolMap = make(map[int]*sync.Pool)
var writeBufferMutex sync.Mutex
func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32) *framer {
func newFramer(conn io.ReadWriter, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32, memPool mem.BufferPool) *framer {
if writeBufferSize < 0 {
writeBufferSize = 0
}
@@ -412,6 +426,8 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBu
f := &framer{
writer: w,
fr: http2.NewFramer(w, r),
reader: r,
pool: memPool,
}
f.fr.SetMaxReadFrameSize(http2MaxFrameLen)
// Opt-in to Frame reuse API on framer to reduce garbage.
@@ -422,6 +438,146 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBu
return f
}
// writeData writes a DATA frame.
//
// It is the caller's responsibility not to violate the maximum frame size.
func (f *framer) writeData(streamID uint32, endStream bool, data [][]byte) error {
var flags http2.Flags
if endStream {
flags = http2.FlagDataEndStream
}
length := uint32(0)
for _, d := range data {
length += uint32(len(d))
}
// TODO: Replace the header write with the framer API being added in
// https://github.com/golang/go/issues/66655.
f.headerBuf = append(f.headerBuf[:0],
byte(length>>16),
byte(length>>8),
byte(length),
byte(http2.FrameData),
byte(flags),
byte(streamID>>24),
byte(streamID>>16),
byte(streamID>>8),
byte(streamID))
if _, err := f.writer.Write(f.headerBuf); err != nil {
return err
}
for _, d := range data {
if _, err := f.writer.Write(d); err != nil {
return err
}
}
return nil
}
// readFrame reads a single frame. The returned Frame is only valid
// until the next call to readFrame.
func (f *framer) readFrame() (any, error) {
f.errDetail = nil
fh, err := f.fr.ReadFrameHeader()
if err != nil {
f.errDetail = f.fr.ErrorDetail()
return nil, err
}
// Read the data frame directly from the underlying io.Reader to avoid
// copies.
if fh.Type == http2.FrameData {
err = f.readDataFrame(fh)
return &f.dataFrame, err
}
fr, err := f.fr.ReadFrameForHeader(fh)
if err != nil {
f.errDetail = f.fr.ErrorDetail()
return nil, err
}
return fr, err
}
// errorDetail returns a more detailed error of the last error
// returned by framer.readFrame. For instance, if readFrame
// returns a StreamError with code PROTOCOL_ERROR, errorDetail
// will say exactly what was invalid. errorDetail is not guaranteed
// to return a non-nil value.
// errorDetail is reset after the next call to readFrame.
func (f *framer) errorDetail() error {
return f.errDetail
}
func (f *framer) readDataFrame(fh http2.FrameHeader) (err error) {
if fh.StreamID == 0 {
// DATA frames MUST be associated with a stream. If a
// DATA frame is received whose stream identifier
// field is 0x0, the recipient MUST respond with a
// connection error (Section 5.4.1) of type
// PROTOCOL_ERROR.
f.errDetail = errors.New("DATA frame with stream ID 0")
return http2.ConnectionError(http2.ErrCodeProtocol)
}
// Converting a *[]byte to a mem.SliceBuffer incurs a heap allocation. This
// conversion is performed by mem.NewBuffer. To avoid the extra allocation
// a []byte is allocated directly if required and cast to a mem.SliceBuffer.
var buf []byte
// poolHandle is the pointer returned by the buffer pool (if it's used.).
var poolHandle *[]byte
useBufferPool := !mem.IsBelowBufferPoolingThreshold(int(fh.Length))
if useBufferPool {
poolHandle = f.pool.Get(int(fh.Length))
buf = *poolHandle
defer func() {
if err != nil {
f.pool.Put(poolHandle)
}
}()
} else {
buf = make([]byte, int(fh.Length))
}
if fh.Flags.Has(http2.FlagDataPadded) {
if fh.Length == 0 {
return io.ErrUnexpectedEOF
}
// This initial 1-byte read can be inefficient for unbuffered readers,
// but it allows the rest of the payload to be read directly to the
// start of the destination slice. This makes it easy to return the
// original slice back to the buffer pool.
if _, err := io.ReadFull(f.reader, buf[:1]); err != nil {
return err
}
padSize := buf[0]
buf = buf[:len(buf)-1]
if int(padSize) > len(buf) {
// If the length of the padding is greater than the
// length of the frame payload, the recipient MUST
// treat this as a connection error.
// Filed: https://github.com/http2/http2-spec/issues/610
f.errDetail = errors.New("pad size larger than data payload")
return http2.ConnectionError(http2.ErrCodeProtocol)
}
if _, err := io.ReadFull(f.reader, buf); err != nil {
return err
}
buf = buf[:len(buf)-int(padSize)]
} else if _, err := io.ReadFull(f.reader, buf); err != nil {
return err
}
f.dataFrame.FrameHeader = fh
if useBufferPool {
// Update the handle to point to the (potentially re-sliced) buf.
*poolHandle = buf
f.dataFrame.data = mem.NewBuffer(poolHandle, f.pool)
} else {
f.dataFrame.data = mem.SliceBuffer(buf)
}
return nil
}
func (df *parsedDataFrame) Header() http2.FrameHeader {
return df.FrameHeader
}
func getWriteBufferPool(size int) *sync.Pool {
writeBufferMutex.Lock()
defer writeBufferMutex.Unlock()

View File

@@ -32,7 +32,7 @@ import (
// ServerStream implements streaming functionality for a gRPC server.
type ServerStream struct {
*Stream // Embed for common stream functionality.
Stream // Embed for common stream functionality.
st internalServerTransport
ctxDone <-chan struct{} // closed at the end of stream. Cache of ctx.Done() (for performance)
@@ -43,12 +43,13 @@ type ServerStream struct {
// Holds compressor names passed in grpc-accept-encoding metadata from the
// client.
clientAdvertisedCompressors string
headerWireLength int
// hdrMu protects outgoing header and trailer metadata.
hdrMu sync.Mutex
header metadata.MD // the outgoing header metadata. Updated by WriteHeader.
headerSent atomic.Bool // atomically set when the headers are sent out.
headerWireLength int
}
// Read reads an n byte message from the input stream.
@@ -178,3 +179,11 @@ func (s *ServerStream) SetTrailer(md metadata.MD) error {
s.hdrMu.Unlock()
return nil
}
func (s *ServerStream) requestRead(n int) {
s.st.adjustWindow(s, uint32(n))
}
func (s *ServerStream) updateWindow(n int) {
s.st.updateWindow(s, uint32(n))
}

View File

@@ -68,11 +68,11 @@ type recvBuffer struct {
err error
}
func newRecvBuffer() *recvBuffer {
b := &recvBuffer{
c: make(chan recvMsg, 1),
}
return b
// init allows a recvBuffer to be initialized in-place, which is useful
// for resetting a buffer or for avoiding a heap allocation when the buffer
// is embedded in another struct.
func (b *recvBuffer) init() {
b.c = make(chan recvMsg, 1)
}
func (b *recvBuffer) put(r recvMsg) {
@@ -123,12 +123,13 @@ func (b *recvBuffer) get() <-chan recvMsg {
// recvBufferReader implements io.Reader interface to read the data from
// recvBuffer.
type recvBufferReader struct {
closeStream func(error) // Closes the client transport stream with the given error and nil trailer metadata.
ctx context.Context
ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
recv *recvBuffer
last mem.Buffer // Stores the remaining data in the previous calls.
err error
_ noCopy
clientStream *ClientStream // The client transport stream is closed with a status representing ctx.Err() and nil trailer metadata.
ctx context.Context
ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
recv *recvBuffer
last mem.Buffer // Stores the remaining data in the previous calls.
err error
}
func (r *recvBufferReader) ReadMessageHeader(header []byte) (n int, err error) {
@@ -139,7 +140,7 @@ func (r *recvBufferReader) ReadMessageHeader(header []byte) (n int, err error) {
n, r.last = mem.ReadUnsafe(header, r.last)
return n, nil
}
if r.closeStream != nil {
if r.clientStream != nil {
n, r.err = r.readMessageHeaderClient(header)
} else {
n, r.err = r.readMessageHeader(header)
@@ -164,7 +165,7 @@ func (r *recvBufferReader) Read(n int) (buf mem.Buffer, err error) {
}
return buf, nil
}
if r.closeStream != nil {
if r.clientStream != nil {
buf, r.err = r.readClient(n)
} else {
buf, r.err = r.read(n)
@@ -209,7 +210,7 @@ func (r *recvBufferReader) readMessageHeaderClient(header []byte) (n int, err er
// TODO: delaying ctx error seems like a unnecessary side effect. What
// we really want is to mark the stream as done, and return ctx error
// faster.
r.closeStream(ContextErr(r.ctx.Err()))
r.clientStream.Close(ContextErr(r.ctx.Err()))
m := <-r.recv.get()
return r.readMessageHeaderAdditional(m, header)
case m := <-r.recv.get():
@@ -236,7 +237,7 @@ func (r *recvBufferReader) readClient(n int) (buf mem.Buffer, err error) {
// TODO: delaying ctx error seems like a unnecessary side effect. What
// we really want is to mark the stream as done, and return ctx error
// faster.
r.closeStream(ContextErr(r.ctx.Err()))
r.clientStream.Close(ContextErr(r.ctx.Err()))
m := <-r.recv.get()
return r.readAdditional(m, n)
case m := <-r.recv.get():
@@ -285,27 +286,32 @@ const (
// Stream represents an RPC in the transport layer.
type Stream struct {
id uint32
ctx context.Context // the associated context of the stream
method string // the associated RPC method of the stream
recvCompress string
sendCompress string
buf *recvBuffer
trReader *transportReader
fc *inFlow
wq *writeQuota
// Callback to state application's intentions to read data. This
// is used to adjust flow control, if needed.
requestRead func(int)
state streamState
readRequester readRequester
// contentSubtype is the content-subtype for requests.
// this must be lowercase or the behavior is undefined.
contentSubtype string
trailer metadata.MD // the key-value map of trailer metadata.
// Non-pointer fields are at the end to optimize GC performance.
state streamState
id uint32
buf recvBuffer
trReader transportReader
fc inFlow
wq writeQuota
}
// readRequester is used to state application's intentions to read data. This
// is used to adjust flow control, if needed.
type readRequester interface {
requestRead(int)
}
func (s *Stream) swapState(st streamState) streamState {
@@ -355,7 +361,7 @@ func (s *Stream) ReadMessageHeader(header []byte) (err error) {
if er := s.trReader.er; er != nil {
return er
}
s.requestRead(len(header))
s.readRequester.requestRead(len(header))
for len(header) != 0 {
n, err := s.trReader.ReadMessageHeader(header)
header = header[n:]
@@ -372,13 +378,29 @@ func (s *Stream) ReadMessageHeader(header []byte) (err error) {
return nil
}
// ceil returns the ceil after dividing the numerator and denominator while
// avoiding integer overflows.
func ceil(numerator, denominator int) int {
if numerator == 0 {
return 0
}
return (numerator-1)/denominator + 1
}
// Read reads n bytes from the wire for this stream.
func (s *Stream) read(n int) (data mem.BufferSlice, err error) {
// Don't request a read if there was an error earlier
if er := s.trReader.er; er != nil {
return nil, er
}
s.requestRead(n)
// gRPC Go accepts data frames with a maximum length of 16KB. Larger
// messages must be split into multiple frames. We pre-allocate the
// buffer to avoid resizing during the read loop, but cap the initial
// capacity to 128 frames (2MB) to prevent over-allocation or panics
// when reading extremely large streams.
allocCap := min(ceil(n, http2MaxFrameLen), 128)
data = make(mem.BufferSlice, 0, allocCap)
s.readRequester.requestRead(n)
for n != 0 {
buf, err := s.trReader.Read(n)
var bufLen int
@@ -401,16 +423,34 @@ func (s *Stream) read(n int) (data mem.BufferSlice, err error) {
return data, nil
}
// noCopy may be embedded into structs which must not be copied
// after the first use.
//
// See https://golang.org/issues/8005#issuecomment-190753527
// for details.
type noCopy struct {
}
func (*noCopy) Lock() {}
func (*noCopy) Unlock() {}
// transportReader reads all the data available for this Stream from the transport and
// passes them into the decoder, which converts them into a gRPC message stream.
// The error is io.EOF when the stream is done or another non-nil error if
// the stream broke.
type transportReader struct {
reader *recvBufferReader
_ noCopy
// The handler to control the window update procedure for both this
// particular stream and the associated transport.
windowHandler func(int)
windowHandler windowHandler
er error
reader recvBufferReader
}
// The handler to control the window update procedure for both this
// particular stream and the associated transport.
type windowHandler interface {
updateWindow(int)
}
func (t *transportReader) ReadMessageHeader(header []byte) (int, error) {
@@ -419,7 +459,7 @@ func (t *transportReader) ReadMessageHeader(header []byte) (int, error) {
t.er = err
return 0, err
}
t.windowHandler(n)
t.windowHandler.updateWindow(n)
return n, nil
}
@@ -429,7 +469,7 @@ func (t *transportReader) Read(n int) (mem.Buffer, error) {
t.er = err
return buf, err
}
t.windowHandler(buf.Len())
t.windowHandler.updateWindow(buf.Len())
return buf, nil
}
@@ -454,7 +494,7 @@ type ServerConfig struct {
ConnectionTimeout time.Duration
Credentials credentials.TransportCredentials
InTapHandle tap.ServerInHandle
StatsHandlers []stats.Handler
StatsHandler stats.Handler
KeepaliveParams keepalive.ServerParameters
KeepalivePolicy keepalive.EnforcementPolicy
InitialWindowSize int32
@@ -529,6 +569,12 @@ type CallHdr struct {
// outbound message.
SendCompress string
// AcceptedCompressors overrides the grpc-accept-encoding header for this
// call. When nil, the transport advertises the default set of registered
// compressors. A non-nil pointer overrides that value (including the empty
// string to advertise none).
AcceptedCompressors *string
// Creds specifies credentials.PerRPCCredentials for a call.
Creds credentials.PerRPCCredentials
@@ -544,9 +590,14 @@ type CallHdr struct {
DoneFunc func() // called when the stream is finished
// Authority is used to explicitly override the `:authority` header. If set,
// this value takes precedence over the Host field and will be used as the
// value for the `:authority` header.
// Authority is used to explicitly override the `:authority` header.
//
// This value comes from one of two sources:
// 1. The `CallAuthority` call option, if specified by the user.
// 2. An override provided by the LB picker (e.g. xDS authority rewriting).
//
// The `CallAuthority` call option always takes precedence over the LB
// picker override.
Authority string
}
@@ -566,7 +617,7 @@ type ClientTransport interface {
GracefulClose()
// NewStream creates a Stream for an RPC.
NewStream(ctx context.Context, callHdr *CallHdr) (*ClientStream, error)
NewStream(ctx context.Context, callHdr *CallHdr, handler stats.Handler) (*ClientStream, error)
// Error returns a channel that is closed when some I/O error
// happens. Typically the caller should have a goroutine to monitor
@@ -584,8 +635,9 @@ type ClientTransport interface {
// with a human readable string with debug info.
GetGoAwayReason() (GoAwayReason, string)
// RemoteAddr returns the remote network address.
RemoteAddr() net.Addr
// Peer returns information about the peer associated with the Transport.
// The returned information includes authentication and network address details.
Peer() *peer.Peer
}
// ServerTransport is the common interface for all gRPC server-side transport
@@ -615,6 +667,8 @@ type internalServerTransport interface {
write(s *ServerStream, hdr []byte, data mem.BufferSlice, opts *WriteOptions) error
writeStatus(s *ServerStream, st *status.Status) error
incrMsgRecv()
adjustWindow(s *ServerStream, n uint32)
updateWindow(s *ServerStream, n uint32)
}
// connectionErrorf creates an ConnectionError with the specified error description.

View File

@@ -32,12 +32,17 @@ type BufferPool interface {
Get(length int) *[]byte
// Put returns a buffer to the pool.
//
// The provided pointer must hold a prefix of the buffer obtained via
// BufferPool.Get to ensure the buffer's entire capacity can be re-used.
Put(*[]byte)
}
const goPageSize = 4 << 10 // 4KiB. N.B. this must be a power of 2.
var defaultBufferPoolSizes = []int{
256,
4 << 10, // 4KB (go page size)
goPageSize,
16 << 10, // 16KB (max HTTP/2 frame size used by gRPC)
32 << 10, // 32KB (default buffer size for io.Copy)
1 << 20, // 1MB
@@ -48,7 +53,7 @@ var defaultBufferPool BufferPool
func init() {
defaultBufferPool = NewTieredBufferPool(defaultBufferPoolSizes...)
internal.SetDefaultBufferPoolForTesting = func(pool BufferPool) {
internal.SetDefaultBufferPool = func(pool BufferPool) {
defaultBufferPool = pool
}
@@ -118,7 +123,11 @@ type sizedBufferPool struct {
}
func (p *sizedBufferPool) Get(size int) *[]byte {
buf := p.pool.Get().(*[]byte)
buf, ok := p.pool.Get().(*[]byte)
if !ok {
buf := make([]byte, size, p.defaultSize)
return &buf
}
b := *buf
clear(b[:cap(b)])
*buf = b[:size]
@@ -137,12 +146,6 @@ func (p *sizedBufferPool) Put(buf *[]byte) {
func newSizedBufferPool(size int) *sizedBufferPool {
return &sizedBufferPool{
pool: sync.Pool{
New: func() any {
buf := make([]byte, size)
return &buf
},
},
defaultSize: size,
}
}
@@ -160,6 +163,7 @@ type simpleBufferPool struct {
func (p *simpleBufferPool) Get(size int) *[]byte {
bs, ok := p.pool.Get().(*[]byte)
if ok && cap(*bs) >= size {
clear((*bs)[:cap(*bs)])
*bs = (*bs)[:size]
return bs
}
@@ -170,7 +174,14 @@ func (p *simpleBufferPool) Get(size int) *[]byte {
p.pool.Put(bs)
}
b := make([]byte, size)
// If we're going to allocate, round up to the nearest page. This way if
// requests frequently arrive with small variation we don't allocate
// repeatedly if we get unlucky and they increase over time. By default we
// only allocate here if size > 1MiB. Because goPageSize is a power of 2, we
// can round up efficiently.
allocSize := (size + goPageSize - 1) & ^(goPageSize - 1)
b := make([]byte, size, allocSize)
return &b
}

View File

@@ -19,6 +19,7 @@
package mem
import (
"fmt"
"io"
)
@@ -117,43 +118,36 @@ func (s BufferSlice) MaterializeToBuffer(pool BufferPool) Buffer {
// Reader returns a new Reader for the input slice after taking references to
// each underlying buffer.
func (s BufferSlice) Reader() Reader {
func (s BufferSlice) Reader() *Reader {
s.Ref()
return &sliceReader{
return &Reader{
data: s,
len: s.Len(),
}
}
// Reader exposes a BufferSlice's data as an io.Reader, allowing it to interface
// with other parts systems. It also provides an additional convenience method
// Remaining(), which returns the number of unread bytes remaining in the slice.
// with other systems.
//
// Buffers will be freed as they are read.
type Reader interface {
io.Reader
io.ByteReader
// Close frees the underlying BufferSlice and never returns an error. Subsequent
// calls to Read will return (0, io.EOF).
Close() error
// Remaining returns the number of unread bytes remaining in the slice.
Remaining() int
// Reset frees the currently held buffer slice and starts reading from the
// provided slice. This allows reusing the reader object.
Reset(s BufferSlice)
}
type sliceReader struct {
//
// A Reader can be constructed from a BufferSlice; alternatively the zero value
// of a Reader may be used after calling Reset on it.
type Reader struct {
data BufferSlice
len int
// The index into data[0].ReadOnlyData().
bufferIdx int
}
func (r *sliceReader) Remaining() int {
// Remaining returns the number of unread bytes remaining in the slice.
func (r *Reader) Remaining() int {
return r.len
}
func (r *sliceReader) Reset(s BufferSlice) {
// Reset frees the currently held buffer slice and starts reading from the
// provided slice. This allows reusing the reader object.
func (r *Reader) Reset(s BufferSlice) {
r.data.Free()
s.Ref()
r.data = s
@@ -161,14 +155,16 @@ func (r *sliceReader) Reset(s BufferSlice) {
r.bufferIdx = 0
}
func (r *sliceReader) Close() error {
// Close frees the underlying BufferSlice and never returns an error. Subsequent
// calls to Read will return (0, io.EOF).
func (r *Reader) Close() error {
r.data.Free()
r.data = nil
r.len = 0
return nil
}
func (r *sliceReader) freeFirstBufferIfEmpty() bool {
func (r *Reader) freeFirstBufferIfEmpty() bool {
if len(r.data) == 0 || r.bufferIdx != len(r.data[0].ReadOnlyData()) {
return false
}
@@ -179,7 +175,7 @@ func (r *sliceReader) freeFirstBufferIfEmpty() bool {
return true
}
func (r *sliceReader) Read(buf []byte) (n int, _ error) {
func (r *Reader) Read(buf []byte) (n int, _ error) {
if r.len == 0 {
return 0, io.EOF
}
@@ -202,7 +198,8 @@ func (r *sliceReader) Read(buf []byte) (n int, _ error) {
return n, nil
}
func (r *sliceReader) ReadByte() (byte, error) {
// ReadByte reads a single byte.
func (r *Reader) ReadByte() (byte, error) {
if r.len == 0 {
return 0, io.EOF
}
@@ -290,3 +287,59 @@ nextBuffer:
}
}
}
// Discard skips the next n bytes, returning the number of bytes discarded.
//
// It frees buffers as they are fully consumed.
//
// If Discard skips fewer than n bytes, it also returns an error.
func (r *Reader) Discard(n int) (discarded int, err error) {
total := n
for n > 0 && r.len > 0 {
curData := r.data[0].ReadOnlyData()
curSize := min(n, len(curData)-r.bufferIdx)
n -= curSize
r.len -= curSize
r.bufferIdx += curSize
if r.bufferIdx >= len(curData) {
r.data[0].Free()
r.data = r.data[1:]
r.bufferIdx = 0
}
}
discarded = total - n
if n > 0 {
return discarded, fmt.Errorf("insufficient bytes in reader")
}
return discarded, nil
}
// Peek returns the next n bytes without advancing the reader.
//
// Peek appends results to the provided res slice and returns the updated slice.
// This pattern allows re-using the storage of res if it has sufficient
// capacity.
//
// The returned subslices are views into the underlying buffers and are only
// valid until the reader is advanced past the corresponding buffer.
//
// If Peek returns fewer than n bytes, it also returns an error.
func (r *Reader) Peek(n int, res [][]byte) ([][]byte, error) {
for i := 0; n > 0 && i < len(r.data); i++ {
curData := r.data[i].ReadOnlyData()
start := 0
if i == 0 {
start = r.bufferIdx
}
curSize := min(n, len(curData)-start)
if curSize == 0 {
continue
}
res = append(res, curData[start:start+curSize])
n -= curSize
}
if n > 0 {
return nil, fmt.Errorf("insufficient bytes in reader")
}
return res, nil
}

View File

@@ -62,7 +62,6 @@ var (
bufferPoolingThreshold = 1 << 10
bufferObjectPool = sync.Pool{New: func() any { return new(buffer) }}
refObjectPool = sync.Pool{New: func() any { return new(atomic.Int32) }}
)
// IsBelowBufferPoolingThreshold returns true if the given size is less than or
@@ -73,9 +72,19 @@ func IsBelowBufferPoolingThreshold(size int) bool {
}
type buffer struct {
refs atomic.Int32
data []byte
// rootBuf is the buffer responsible for returning origData to the pool
// once the reference count drops to 0.
//
// When a buffer is split, the new buffer inherits the rootBuf of the
// original and increments the root's reference count. For the
// initial buffer (the root), this field points to itself.
rootBuf *buffer
// The following fields are only set for root buffers.
origData *[]byte
data []byte
refs *atomic.Int32
pool BufferPool
}
@@ -103,8 +112,8 @@ func NewBuffer(data *[]byte, pool BufferPool) Buffer {
b.origData = data
b.data = *data
b.pool = pool
b.refs = refObjectPool.Get().(*atomic.Int32)
b.refs.Add(1)
b.rootBuf = b
b.refs.Store(1)
return b
}
@@ -127,42 +136,44 @@ func Copy(data []byte, pool BufferPool) Buffer {
}
func (b *buffer) ReadOnlyData() []byte {
if b.refs == nil {
if b.rootBuf == nil {
panic("Cannot read freed buffer")
}
return b.data
}
func (b *buffer) Ref() {
if b.refs == nil {
if b.refs.Add(1) <= 1 {
panic("Cannot ref freed buffer")
}
b.refs.Add(1)
}
func (b *buffer) Free() {
if b.refs == nil {
refs := b.refs.Add(-1)
if refs < 0 {
panic("Cannot free freed buffer")
}
refs := b.refs.Add(-1)
switch {
case refs > 0:
if refs > 0 {
return
case refs == 0:
}
b.data = nil
if b.rootBuf == b {
// This buffer is the owner of the data slice and its ref count reached
// 0, free the slice.
if b.pool != nil {
b.pool.Put(b.origData)
b.pool = nil
}
refObjectPool.Put(b.refs)
b.origData = nil
b.data = nil
b.refs = nil
b.pool = nil
bufferObjectPool.Put(b)
default:
panic("Cannot free freed buffer")
} else {
// This buffer doesn't own the data slice, decrement a ref on the root
// buffer.
b.rootBuf.Free()
}
b.rootBuf = nil
bufferObjectPool.Put(b)
}
func (b *buffer) Len() int {
@@ -170,16 +181,14 @@ func (b *buffer) Len() int {
}
func (b *buffer) split(n int) (Buffer, Buffer) {
if b.refs == nil {
if b.rootBuf == nil || b.rootBuf.refs.Add(1) <= 1 {
panic("Cannot split freed buffer")
}
b.refs.Add(1)
split := newBuffer()
split.origData = b.origData
split.data = b.data[n:]
split.refs = b.refs
split.pool = b.pool
split.rootBuf = b.rootBuf
split.refs.Store(1)
b.data = b.data[:n]
@@ -187,7 +196,7 @@ func (b *buffer) split(n int) (Buffer, Buffer) {
}
func (b *buffer) read(buf []byte) (int, Buffer) {
if b.refs == nil {
if b.rootBuf == nil {
panic("Cannot read freed buffer")
}

View File

@@ -29,7 +29,6 @@ import (
"google.golang.org/grpc/internal/channelz"
istatus "google.golang.org/grpc/internal/status"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
)
@@ -48,14 +47,11 @@ type pickerGeneration struct {
// actions and unblock when there's a picker update.
type pickerWrapper struct {
// If pickerGen holds a nil pointer, the pickerWrapper is closed.
pickerGen atomic.Pointer[pickerGeneration]
statsHandlers []stats.Handler // to record blocking picker calls
pickerGen atomic.Pointer[pickerGeneration]
}
func newPickerWrapper(statsHandlers []stats.Handler) *pickerWrapper {
pw := &pickerWrapper{
statsHandlers: statsHandlers,
}
func newPickerWrapper() *pickerWrapper {
pw := &pickerWrapper{}
pw.pickerGen.Store(&pickerGeneration{
blockingCh: make(chan struct{}),
})
@@ -93,6 +89,12 @@ func doneChannelzWrapper(acbw *acBalancerWrapper, result *balancer.PickResult) {
}
}
type pick struct {
transport transport.ClientTransport // the selected transport
result balancer.PickResult // the contents of the pick from the LB policy
blocked bool // set if a picker call queued for a new picker
}
// pick returns the transport that will be used for the RPC.
// It may block in the following cases:
// - there's no picker
@@ -100,15 +102,16 @@ func doneChannelzWrapper(acbw *acBalancerWrapper, result *balancer.PickResult) {
// - the current picker returns other errors and failfast is false.
// - the subConn returned by the current picker is not READY
// When one of these situations happens, pick blocks until the picker gets updated.
func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.PickInfo) (transport.ClientTransport, balancer.PickResult, error) {
func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.PickInfo) (pick, error) {
var ch chan struct{}
var lastPickErr error
pickBlocked := false
for {
pg := pw.pickerGen.Load()
if pg == nil {
return nil, balancer.PickResult{}, ErrClientConnClosing
return pick{}, ErrClientConnClosing
}
if pg.picker == nil {
ch = pg.blockingCh
@@ -127,9 +130,9 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
}
switch ctx.Err() {
case context.DeadlineExceeded:
return nil, balancer.PickResult{}, status.Error(codes.DeadlineExceeded, errStr)
return pick{}, status.Error(codes.DeadlineExceeded, errStr)
case context.Canceled:
return nil, balancer.PickResult{}, status.Error(codes.Canceled, errStr)
return pick{}, status.Error(codes.Canceled, errStr)
}
case <-ch:
}
@@ -145,9 +148,7 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
// In the second case, the only way it will get to this conditional is
// if there is a new picker.
if ch != nil {
for _, sh := range pw.statsHandlers {
sh.HandleRPC(ctx, &stats.PickerUpdated{})
}
pickBlocked = true
}
ch = pg.blockingCh
@@ -164,7 +165,7 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
if istatus.IsRestrictedControlPlaneCode(st) {
err = status.Errorf(codes.Internal, "received picker error with illegal status: %v", err)
}
return nil, balancer.PickResult{}, dropError{error: err}
return pick{}, dropError{error: err}
}
// For all other errors, wait for ready RPCs should block and other
// RPCs should fail with unavailable.
@@ -172,7 +173,7 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
lastPickErr = err
continue
}
return nil, balancer.PickResult{}, status.Error(codes.Unavailable, err.Error())
return pick{}, status.Error(codes.Unavailable, err.Error())
}
acbw, ok := pickResult.SubConn.(*acBalancerWrapper)
@@ -183,9 +184,8 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
if t := acbw.ac.getReadyTransport(); t != nil {
if channelz.IsOn() {
doneChannelzWrapper(acbw, &pickResult)
return t, pickResult, nil
}
return t, pickResult, nil
return pick{transport: t, result: pickResult, blocked: pickBlocked}, nil
}
if pickResult.Done != nil {
// Calling done with nil error, no bytes sent and no bytes received.

View File

@@ -47,9 +47,6 @@ func (p *PreparedMsg) Encode(s Stream, msg any) error {
}
// check if the context has the relevant information to prepareMsg
if rpcInfo.preloaderInfo == nil {
return status.Errorf(codes.Internal, "grpc: rpcInfo.preloaderInfo is nil")
}
if rpcInfo.preloaderInfo.codec == nil {
return status.Errorf(codes.Internal, "grpc: rpcInfo.preloaderInfo.codec is nil")
}

View File

@@ -182,6 +182,7 @@ type BuildOptions struct {
// An Endpoint is one network endpoint, or server, which may have multiple
// addresses with which it can be accessed.
// TODO(i/8773) : make resolver.Endpoint and resolver.Address immutable
type Endpoint struct {
// Addresses contains a list of addresses used to access this endpoint.
Addresses []Address
@@ -332,6 +333,11 @@ type AuthorityOverrider interface {
// OverrideAuthority returns the authority to use for a ClientConn with the
// given target. The implementation must generate it without blocking,
// typically in line, and must keep it unchanged.
//
// The returned string must be a valid ":authority" header value, i.e. be
// encoded according to
// [RFC3986](https://datatracker.ietf.org/doc/html/rfc3986#section-3.2) as
// necessary.
OverrideAuthority(Target) string
}

View File

@@ -69,6 +69,7 @@ func (ccr *ccResolverWrapper) start() error {
errCh := make(chan error)
ccr.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil {
errCh <- ctx.Err()
return
}
opts := resolver.BuildOptions{

View File

@@ -33,6 +33,8 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
@@ -41,6 +43,10 @@ import (
"google.golang.org/grpc/status"
)
func init() {
internal.AcceptCompressors = acceptCompressors
}
// Compressor defines the interface gRPC uses to compress a message.
//
// Deprecated: use package encoding.
@@ -151,16 +157,32 @@ func (d *gzipDecompressor) Type() string {
// callInfo contains all related configuration and information about an RPC.
type callInfo struct {
compressorName string
failFast bool
maxReceiveMessageSize *int
maxSendMessageSize *int
creds credentials.PerRPCCredentials
contentSubtype string
codec baseCodec
maxRetryRPCBufferSize int
onFinish []func(err error)
authority string
compressorName string
failFast bool
maxReceiveMessageSize *int
maxSendMessageSize *int
creds credentials.PerRPCCredentials
contentSubtype string
codec baseCodec
maxRetryRPCBufferSize int
onFinish []func(err error)
authority string
acceptedResponseCompressors []string
}
func acceptedCompressorAllows(allowed []string, name string) bool {
if allowed == nil {
return true
}
if name == "" || name == encoding.Identity {
return true
}
for _, a := range allowed {
if a == name {
return true
}
}
return false
}
func defaultCallInfo() *callInfo {
@@ -170,6 +192,29 @@ func defaultCallInfo() *callInfo {
}
}
func newAcceptedCompressionConfig(names []string) ([]string, error) {
if len(names) == 0 {
return nil, nil
}
var allowed []string
seen := make(map[string]struct{}, len(names))
for _, name := range names {
name = strings.TrimSpace(name)
if name == "" || name == encoding.Identity {
continue
}
if !grpcutil.IsCompressorNameRegistered(name) {
return nil, status.Errorf(codes.InvalidArgument, "grpc: compressor %q is not registered", name)
}
if _, dup := seen[name]; dup {
continue
}
seen[name] = struct{}{}
allowed = append(allowed, name)
}
return allowed, nil
}
// CallOption configures a Call before it starts or extracts information from
// a Call after it completes.
type CallOption interface {
@@ -471,6 +516,31 @@ func (o CompressorCallOption) before(c *callInfo) error {
}
func (o CompressorCallOption) after(*callInfo, *csAttempt) {}
// acceptCompressors returns a CallOption that limits the compression algorithms
// advertised in the grpc-accept-encoding header for response messages.
// Compression algorithms not in the provided list will not be advertised, and
// responses compressed with non-listed algorithms will be rejected.
func acceptCompressors(names ...string) CallOption {
cp := append([]string(nil), names...)
return acceptCompressorsCallOption{names: cp}
}
// acceptCompressorsCallOption is a CallOption that limits response compression.
type acceptCompressorsCallOption struct {
names []string
}
func (o acceptCompressorsCallOption) before(c *callInfo) error {
allowed, err := newAcceptedCompressionConfig(o.names)
if err != nil {
return err
}
c.acceptedResponseCompressors = allowed
return nil
}
func (acceptCompressorsCallOption) after(*callInfo, *csAttempt) {}
// CallContentSubtype returns a CallOption that will set the content-subtype
// for a call. For example, if content-subtype is "json", the Content-Type over
// the wire will be "application/grpc+json". The content-subtype is converted
@@ -657,8 +727,20 @@ type streamReader interface {
Read(n int) (mem.BufferSlice, error)
}
// noCopy may be embedded into structs which must not be copied
// after the first use.
//
// See https://golang.org/issues/8005#issuecomment-190753527
// for details.
type noCopy struct {
}
func (*noCopy) Lock() {}
func (*noCopy) Unlock() {}
// parser reads complete gRPC messages from the underlying reader.
type parser struct {
_ noCopy
// r is the underlying reader.
// See the comment on recvMsg for the permissible
// error types.
@@ -845,8 +927,7 @@ func (p *payloadInfo) free() {
// the buffer is no longer needed.
// TODO: Refactor this function to reduce the number of arguments.
// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
) (out mem.BufferSlice, err error) {
func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) (out mem.BufferSlice, err error) {
pf, compressed, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
return nil, err
@@ -949,7 +1030,7 @@ func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxR
// Information about RPC
type rpcInfo struct {
failfast bool
preloaderInfo *compressorInfo
preloaderInfo compressorInfo
}
// Information about Preloader
@@ -968,7 +1049,7 @@ type rpcInfoContextKey struct{}
func newContextWithRPCInfo(ctx context.Context, failfast bool, codec baseCodec, cp Compressor, comp encoding.Compressor) context.Context {
return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{
failfast: failfast,
preloaderInfo: &compressorInfo{
preloaderInfo: compressorInfo{
codec: codec,
cp: cp,
comp: comp,

View File

@@ -42,6 +42,7 @@ import (
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/binarylog"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpcutil"
istats "google.golang.org/grpc/internal/stats"
@@ -124,7 +125,8 @@ type serviceInfo struct {
// Server is a gRPC server to serve RPC requests.
type Server struct {
opts serverOptions
opts serverOptions
statsHandler stats.Handler
mu sync.Mutex // guards following
lis map[net.Listener]bool
@@ -148,6 +150,8 @@ type Server struct {
serverWorkerChannel chan func()
serverWorkerChannelClose func()
strictPathCheckingLogEmitted atomic.Bool
}
type serverOptions struct {
@@ -692,13 +696,14 @@ func NewServer(opt ...ServerOption) *Server {
o.apply(&opts)
}
s := &Server{
lis: make(map[net.Listener]bool),
opts: opts,
conns: make(map[string]map[transport.ServerTransport]bool),
services: make(map[string]*serviceInfo),
quit: grpcsync.NewEvent(),
done: grpcsync.NewEvent(),
channelz: channelz.RegisterServer(""),
lis: make(map[net.Listener]bool),
opts: opts,
statsHandler: istats.NewCombinedHandler(opts.statsHandlers...),
conns: make(map[string]map[transport.ServerTransport]bool),
services: make(map[string]*serviceInfo),
quit: grpcsync.NewEvent(),
done: grpcsync.NewEvent(),
channelz: channelz.RegisterServer(""),
}
chainUnaryServerInterceptors(s)
chainStreamServerInterceptors(s)
@@ -921,9 +926,7 @@ func (s *Server) Serve(lis net.Listener) error {
tempDelay = 5 * time.Millisecond
} else {
tempDelay *= 2
}
if max := 1 * time.Second; tempDelay > max {
tempDelay = max
tempDelay = min(tempDelay, 1*time.Second)
}
s.mu.Lock()
s.printf("Accept error: %v; retrying in %v", err, tempDelay)
@@ -999,7 +1002,7 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
ConnectionTimeout: s.opts.connectionTimeout,
Credentials: s.opts.creds,
InTapHandle: s.opts.inTapHandle,
StatsHandlers: s.opts.statsHandlers,
StatsHandler: s.statsHandler,
KeepaliveParams: s.opts.keepaliveParams,
KeepalivePolicy: s.opts.keepalivePolicy,
InitialWindowSize: s.opts.initialWindowSize,
@@ -1036,18 +1039,18 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, rawConn net.Conn) {
ctx = transport.SetConnection(ctx, rawConn)
ctx = peer.NewContext(ctx, st.Peer())
for _, sh := range s.opts.statsHandlers {
ctx = sh.TagConn(ctx, &stats.ConnTagInfo{
if s.statsHandler != nil {
ctx = s.statsHandler.TagConn(ctx, &stats.ConnTagInfo{
RemoteAddr: st.Peer().Addr,
LocalAddr: st.Peer().LocalAddr,
})
sh.HandleConn(ctx, &stats.ConnBegin{})
s.statsHandler.HandleConn(ctx, &stats.ConnBegin{})
}
defer func() {
st.Close(errors.New("finished serving streams for the server transport"))
for _, sh := range s.opts.statsHandlers {
sh.HandleConn(ctx, &stats.ConnEnd{})
if s.statsHandler != nil {
s.statsHandler.HandleConn(ctx, &stats.ConnEnd{})
}
}()
@@ -1104,7 +1107,7 @@ var _ http.Handler = (*Server)(nil)
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers, s.opts.bufferPool)
st, err := transport.NewServerHandlerTransport(w, r, s.statsHandler, s.opts.bufferPool)
if err != nil {
// Errors returned from transport.NewServerHandlerTransport have
// already been written to w.
@@ -1198,12 +1201,8 @@ func (s *Server) sendResponse(ctx context.Context, stream *transport.ServerStrea
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", payloadLen, s.opts.maxSendMessageSize)
}
err = stream.Write(hdr, payload, opts)
if err == nil {
if len(s.opts.statsHandlers) != 0 {
for _, sh := range s.opts.statsHandlers {
sh.HandleRPC(ctx, outPayload(false, msg, dataLen, payloadLen, time.Now()))
}
}
if err == nil && s.statsHandler != nil {
s.statsHandler.HandleRPC(ctx, outPayload(false, msg, dataLen, payloadLen, time.Now()))
}
return err
}
@@ -1245,16 +1244,15 @@ func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info
}
func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerStream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) {
shs := s.opts.statsHandlers
if len(shs) != 0 || trInfo != nil || channelz.IsOn() {
sh := s.statsHandler
if sh != nil || trInfo != nil || channelz.IsOn() {
if channelz.IsOn() {
s.incrCallsStarted()
}
var statsBegin *stats.Begin
for _, sh := range shs {
beginTime := time.Now()
if sh != nil {
statsBegin = &stats.Begin{
BeginTime: beginTime,
BeginTime: time.Now(),
IsClientStream: false,
IsServerStream: false,
}
@@ -1282,7 +1280,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt
trInfo.tr.Finish()
}
for _, sh := range shs {
if sh != nil {
end := &stats.End{
BeginTime: statsBegin.BeginTime,
EndTime: time.Now(),
@@ -1379,7 +1377,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt
}
var payInfo *payloadInfo
if len(shs) != 0 || len(binlogs) != 0 {
if sh != nil || len(binlogs) != 0 {
payInfo = &payloadInfo{}
defer payInfo.free()
}
@@ -1405,7 +1403,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
}
for _, sh := range shs {
if sh != nil {
sh.HandleRPC(ctx, &stats.InPayload{
RecvTime: time.Now(),
Payload: v,
@@ -1579,32 +1577,30 @@ func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.Serv
if channelz.IsOn() {
s.incrCallsStarted()
}
shs := s.opts.statsHandlers
sh := s.statsHandler
var statsBegin *stats.Begin
if len(shs) != 0 {
beginTime := time.Now()
if sh != nil {
statsBegin = &stats.Begin{
BeginTime: beginTime,
BeginTime: time.Now(),
IsClientStream: sd.ClientStreams,
IsServerStream: sd.ServerStreams,
}
for _, sh := range shs {
sh.HandleRPC(ctx, statsBegin)
}
sh.HandleRPC(ctx, statsBegin)
}
ctx = NewContextWithServerTransportStream(ctx, stream)
ss := &serverStream{
ctx: ctx,
s: stream,
p: &parser{r: stream, bufferPool: s.opts.bufferPool},
p: parser{r: stream, bufferPool: s.opts.bufferPool},
codec: s.getCodec(stream.ContentSubtype()),
desc: sd,
maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
maxSendMessageSize: s.opts.maxSendMessageSize,
trInfo: trInfo,
statsHandler: shs,
statsHandler: sh,
}
if len(shs) != 0 || trInfo != nil || channelz.IsOn() {
if sh != nil || trInfo != nil || channelz.IsOn() {
// See comment in processUnaryRPC on defers.
defer func() {
if trInfo != nil {
@@ -1618,7 +1614,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.Serv
ss.mu.Unlock()
}
if len(shs) != 0 {
if sh != nil {
end := &stats.End{
BeginTime: statsBegin.BeginTime,
EndTime: time.Now(),
@@ -1626,9 +1622,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.Serv
if err != nil && err != io.EOF {
end.Error = toRPCErr(err)
}
for _, sh := range shs {
sh.HandleRPC(ctx, end)
}
sh.HandleRPC(ctx, end)
}
if channelz.IsOn() {
@@ -1771,6 +1765,24 @@ func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.Serv
return ss.s.WriteStatus(statusOK)
}
func (s *Server) handleMalformedMethodName(stream *transport.ServerStream, ti *traceInfo) {
if ti != nil {
ti.tr.LazyLog(&fmtStringer{"Malformed method name %q", []any{stream.Method()}}, true)
ti.tr.SetError()
}
errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
if err := stream.WriteStatus(status.New(codes.Unimplemented, errDesc)); err != nil {
if ti != nil {
ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
ti.tr.SetError()
}
channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err)
}
if ti != nil {
ti.tr.Finish()
}
}
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.ServerStream) {
ctx := stream.Context()
ctx = contextWithServer(ctx, s)
@@ -1791,45 +1803,47 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Ser
}
sm := stream.Method()
if sm != "" && sm[0] == '/' {
if sm == "" {
s.handleMalformedMethodName(stream, ti)
return
}
if sm[0] != '/' {
// TODO(easwars): Add a link to the CVE in the below log messages once
// published.
if envconfig.DisableStrictPathChecking {
if old := s.strictPathCheckingLogEmitted.Swap(true); !old {
channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream received malformed method name %q. Allowing it because the environment variable GRPC_GO_EXPERIMENTAL_DISABLE_STRICT_PATH_CHECKING is set to true, but this option will be removed in a future release.", sm)
}
} else {
if old := s.strictPathCheckingLogEmitted.Swap(true); !old {
channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream rejected malformed method name %q. To temporarily allow such requests, set the environment variable GRPC_GO_EXPERIMENTAL_DISABLE_STRICT_PATH_CHECKING to true. Note that this is not recommended as it may allow requests to bypass security policies.", sm)
}
s.handleMalformedMethodName(stream, ti)
return
}
} else {
sm = sm[1:]
}
pos := strings.LastIndex(sm, "/")
if pos == -1 {
if ti != nil {
ti.tr.LazyLog(&fmtStringer{"Malformed method name %q", []any{sm}}, true)
ti.tr.SetError()
}
errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
if err := stream.WriteStatus(status.New(codes.Unimplemented, errDesc)); err != nil {
if ti != nil {
ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
ti.tr.SetError()
}
channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err)
}
if ti != nil {
ti.tr.Finish()
}
s.handleMalformedMethodName(stream, ti)
return
}
service := sm[:pos]
method := sm[pos+1:]
// FromIncomingContext is expensive: skip if there are no statsHandlers
if len(s.opts.statsHandlers) > 0 {
if s.statsHandler != nil {
md, _ := metadata.FromIncomingContext(ctx)
for _, sh := range s.opts.statsHandlers {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()})
sh.HandleRPC(ctx, &stats.InHeader{
FullMethod: stream.Method(),
RemoteAddr: t.Peer().Addr,
LocalAddr: t.Peer().LocalAddr,
Compression: stream.RecvCompress(),
WireLength: stream.HeaderWireLength(),
Header: md,
})
}
ctx = s.statsHandler.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()})
s.statsHandler.HandleRPC(ctx, &stats.InHeader{
FullMethod: stream.Method(),
RemoteAddr: t.Peer().Addr,
LocalAddr: t.Peer().LocalAddr,
Compression: stream.RecvCompress(),
WireLength: stream.HeaderWireLength(),
Header: md,
})
}
// To have calls in stream callouts work. Will delete once all stats handler
// calls come from the gRPC layer.

View File

@@ -64,15 +64,21 @@ func (s *Begin) IsClient() bool { return s.Client }
func (s *Begin) isRPCStats() {}
// PickerUpdated indicates that the LB policy provided a new picker while the
// RPC was waiting for one.
type PickerUpdated struct{}
// DelayedPickComplete indicates that the RPC is unblocked following a delay in
// selecting a connection for the call.
type DelayedPickComplete struct{}
// IsClient indicates if the stats information is from client side. Only Client
// Side interfaces with a Picker, thus always returns true.
func (*PickerUpdated) IsClient() bool { return true }
// IsClient indicates DelayedPickComplete is available on the client.
func (*DelayedPickComplete) IsClient() bool { return true }
func (*PickerUpdated) isRPCStats() {}
func (*DelayedPickComplete) isRPCStats() {}
// PickerUpdated indicates that the RPC is unblocked following a delay in
// selecting a connection for the call.
//
// Deprecated: will be removed in a future release; use DelayedPickComplete
// instead.
type PickerUpdated = DelayedPickComplete
// InPayload contains stats about an incoming payload.
type InPayload struct {

View File

@@ -25,6 +25,7 @@ import (
"math"
rand "math/rand/v2"
"strconv"
"strings"
"sync"
"time"
@@ -51,7 +52,8 @@ import (
var metadataFromOutgoingContextRaw = internal.FromOutgoingContextRaw.(func(context.Context) (metadata.MD, [][]string, bool))
// StreamHandler defines the handler called by gRPC server to complete the
// execution of a streaming RPC.
// execution of a streaming RPC. srv is the service implementation on which the
// RPC was invoked.
//
// If a StreamHandler returns an error, it should either be produced by the
// status package, or be one of the context errors. Otherwise, gRPC will use
@@ -177,13 +179,43 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
return cc.NewStream(ctx, desc, method, opts...)
}
var emptyMethodConfig = serviceconfig.MethodConfig{}
// endOfClientStream performs cleanup actions required for both successful and
// failed streams. This includes incrementing channelz stats and invoking all
// registered OnFinish call options.
func endOfClientStream(cc *ClientConn, err error, opts ...CallOption) {
if channelz.IsOn() {
if err != nil {
cc.incrCallsFailed()
} else {
cc.incrCallsSucceeded()
}
}
for _, o := range opts {
if o, ok := o.(OnFinishCallOption); ok {
o.OnFinish(err)
}
}
}
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
if channelz.IsOn() {
cc.incrCallsStarted()
}
defer func() {
if err != nil {
// Ensure cleanup when stream creation fails.
endOfClientStream(cc, err, opts...)
}
}()
// Start tracking the RPC for idleness purposes. This is where a stream is
// created for both streaming and unary RPCs, and hence is a good place to
// track active RPC count.
if err := cc.idlenessMgr.OnCallBegin(); err != nil {
return nil, err
}
cc.idlenessMgr.OnCallBegin()
// Add a calloption, to decrement the active call count, that gets executed
// when the RPC completes.
opts = append([]CallOption{OnFinish(func(error) { cc.idlenessMgr.OnCallEnd() })}, opts...)
@@ -202,14 +234,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}
}
}
if channelz.IsOn() {
cc.incrCallsStarted()
defer func() {
if err != nil {
cc.incrCallsFailed()
}
}()
}
// Provide an opportunity for the first RPC to see the first service config
// provided by the resolver.
nameResolutionDelayed, err := cc.waitForResolvedAddrs(ctx)
@@ -217,7 +241,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
return nil, err
}
var mc serviceconfig.MethodConfig
mc := &emptyMethodConfig
var onCommit func()
newStream := func(ctx context.Context, done func()) (iresolver.ClientStream, error) {
return newClientStreamWithParams(ctx, desc, cc, method, mc, onCommit, done, nameResolutionDelayed, opts...)
@@ -240,7 +264,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if rpcConfig.Context != nil {
ctx = rpcConfig.Context
}
mc = rpcConfig.MethodConfig
mc = &rpcConfig.MethodConfig
onCommit = rpcConfig.OnCommitted
if rpcConfig.Interceptor != nil {
rpcInfo.Context = nil
@@ -258,7 +282,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
return newStream(ctx, func() {})
}
func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, mc serviceconfig.MethodConfig, onCommit, doneFunc func(), nameResolutionDelayed bool, opts ...CallOption) (_ iresolver.ClientStream, err error) {
func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, mc *serviceconfig.MethodConfig, onCommit, doneFunc func(), nameResolutionDelayed bool, opts ...CallOption) (_ iresolver.ClientStream, err error) {
callInfo := defaultCallInfo()
if mc.WaitForReady != nil {
callInfo.failFast = !*mc.WaitForReady
@@ -299,6 +323,10 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client
DoneFunc: doneFunc,
Authority: callInfo.authority,
}
if allowed := callInfo.acceptedResponseCompressors; len(allowed) > 0 {
headerValue := strings.Join(allowed, ",")
callHdr.AcceptedCompressors = &headerValue
}
// Set our outgoing compression according to the UseCompressor CallOption, if
// set. In that case, also find the compressor from the encoding package.
@@ -325,7 +353,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client
cs := &clientStream{
callHdr: callHdr,
ctx: ctx,
methodConfig: &mc,
methodConfig: mc,
opts: opts,
callInfo: callInfo,
cc: cc,
@@ -418,19 +446,21 @@ func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error)
ctx := newContextWithRPCInfo(cs.ctx, cs.callInfo.failFast, cs.callInfo.codec, cs.compressorV0, cs.compressorV1)
method := cs.callHdr.Method
var beginTime time.Time
shs := cs.cc.dopts.copts.StatsHandlers
for _, sh := range shs {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: cs.callInfo.failFast, NameResolutionDelay: cs.nameResolutionDelay})
sh := cs.cc.statsHandler
if sh != nil {
beginTime = time.Now()
begin := &stats.Begin{
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{
FullMethodName: method, FailFast: cs.callInfo.failFast,
NameResolutionDelay: cs.nameResolutionDelay,
})
sh.HandleRPC(ctx, &stats.Begin{
Client: true,
BeginTime: beginTime,
FailFast: cs.callInfo.failFast,
IsClientStream: cs.desc.ClientStreams,
IsServerStream: cs.desc.ServerStreams,
IsTransparentRetryAttempt: isTransparent,
}
sh.HandleRPC(ctx, begin)
})
}
var trInfo *traceInfo
@@ -461,7 +491,7 @@ func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error)
beginTime: beginTime,
cs: cs,
decompressorV0: cs.cc.dopts.dc,
statsHandlers: shs,
statsHandler: sh,
trInfo: trInfo,
}, nil
}
@@ -469,8 +499,9 @@ func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error)
func (a *csAttempt) getTransport() error {
cs := a.cs
var err error
a.transport, a.pickResult, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method)
pickInfo := balancer.PickInfo{Ctx: a.ctx, FullMethodName: cs.callHdr.Method}
pick, err := cs.cc.pickerWrapper.pick(a.ctx, cs.callInfo.failFast, pickInfo)
a.transport, a.pickResult = pick.transport, pick.result
if err != nil {
if de, ok := err.(dropError); ok {
err = de.error
@@ -479,7 +510,10 @@ func (a *csAttempt) getTransport() error {
return err
}
if a.trInfo != nil {
a.trInfo.firstLine.SetRemoteAddr(a.transport.RemoteAddr())
a.trInfo.firstLine.SetRemoteAddr(a.transport.Peer().Addr)
}
if pick.blocked && a.statsHandler != nil {
a.statsHandler.HandleRPC(a.ctx, &stats.DelayedPickComplete{})
}
return nil
}
@@ -504,9 +538,17 @@ func (a *csAttempt) newStream() error {
md, _ := metadata.FromOutgoingContext(a.ctx)
md = metadata.Join(md, a.pickResult.Metadata)
a.ctx = metadata.NewOutgoingContext(a.ctx, md)
}
s, err := a.transport.NewStream(a.ctx, cs.callHdr)
// If the `CallAuthority` CallOption is not set, check if the LB picker
// has provided an authority override in the PickResult metadata and
// apply it, as specified in gRFC A81.
if cs.callInfo.authority == "" {
if authMD := a.pickResult.Metadata.Get(":authority"); len(authMD) > 0 {
cs.callHdr.Authority = authMD[0]
}
}
}
s, err := a.transport.NewStream(a.ctx, cs.callHdr, a.statsHandler)
if err != nil {
nse, ok := err.(*transport.NewStreamError)
if !ok {
@@ -523,7 +565,7 @@ func (a *csAttempt) newStream() error {
}
a.transportStream = s
a.ctx = s.Context()
a.parser = &parser{r: s, bufferPool: a.cs.cc.dopts.copts.BufferPool}
a.parser = parser{r: s, bufferPool: a.cs.cc.dopts.copts.BufferPool}
return nil
}
@@ -543,6 +585,8 @@ type clientStream struct {
sentLast bool // sent an end stream
receivedFirstMsg bool // set after the first message is received
methodConfig *MethodConfig
ctx context.Context // the application's context, wrapped by stats/tracing
@@ -593,7 +637,7 @@ type csAttempt struct {
cs *clientStream
transport transport.ClientTransport
transportStream *transport.ClientStream
parser *parser
parser parser
pickResult balancer.PickResult
finished bool
@@ -607,8 +651,8 @@ type csAttempt struct {
// and cleared when the finish method is called.
trInfo *traceInfo
statsHandlers []stats.Handler
beginTime time.Time
statsHandler stats.Handler
beginTime time.Time
// set for newStream errors that may be transparently retried
allowTransparentRetry bool
@@ -1032,9 +1076,6 @@ func (cs *clientStream) finish(err error) {
return
}
cs.finished = true
for _, onFinish := range cs.callInfo.onFinish {
onFinish(err)
}
cs.commitAttemptLocked()
if cs.attempt != nil {
cs.attempt.finish(err)
@@ -1074,13 +1115,7 @@ func (cs *clientStream) finish(err error) {
if err == nil {
cs.retryThrottler.successfulRPC()
}
if channelz.IsOn() {
if err != nil {
cs.cc.incrCallsFailed()
} else {
cs.cc.incrCallsSucceeded()
}
}
endOfClientStream(cs.cc, err, cs.opts...)
cs.cancel()
}
@@ -1102,17 +1137,15 @@ func (a *csAttempt) sendMsg(m any, hdr []byte, payld mem.BufferSlice, dataLength
}
return io.EOF
}
if len(a.statsHandlers) != 0 {
for _, sh := range a.statsHandlers {
sh.HandleRPC(a.ctx, outPayload(true, m, dataLength, payloadLength, time.Now()))
}
if a.statsHandler != nil {
a.statsHandler.HandleRPC(a.ctx, outPayload(true, m, dataLength, payloadLength, time.Now()))
}
return nil
}
func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
cs := a.cs
if len(a.statsHandlers) != 0 && payInfo == nil {
if a.statsHandler != nil && payInfo == nil {
payInfo = &payloadInfo{}
defer payInfo.free()
}
@@ -1126,6 +1159,10 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
a.decompressorV0 = nil
a.decompressorV1 = encoding.GetCompressor(ct)
}
// Validate that the compression method is acceptable for this call.
if !acceptedCompressorAllows(cs.callInfo.acceptedResponseCompressors, ct) {
return status.Errorf(codes.Internal, "grpc: peer compressed the response with %q which is not allowed by AcceptCompressors", ct)
}
} else {
// No compression is used; disable our decompressor.
a.decompressorV0 = nil
@@ -1133,16 +1170,21 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
// Only initialize this state once per stream.
a.decompressorSet = true
}
if err := recv(a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decompressorV1, false); err != nil {
if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decompressorV1, false); err != nil {
if err == io.EOF {
if statusErr := a.transportStream.Status().Err(); statusErr != nil {
return statusErr
}
// Received no msg and status OK for non-server streaming rpcs.
if !cs.desc.ServerStreams && !cs.receivedFirstMsg {
return status.Error(codes.Internal, "cardinality violation: received no response message from non-server-streaming RPC")
}
return io.EOF // indicates successful end of stream.
}
return toRPCErr(err)
}
cs.receivedFirstMsg = true
if a.trInfo != nil {
a.mu.Lock()
if a.trInfo.tr != nil {
@@ -1150,8 +1192,8 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
}
a.mu.Unlock()
}
for _, sh := range a.statsHandlers {
sh.HandleRPC(a.ctx, &stats.InPayload{
if a.statsHandler != nil {
a.statsHandler.HandleRPC(a.ctx, &stats.InPayload{
Client: true,
RecvTime: time.Now(),
Payload: m,
@@ -1166,12 +1208,12 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
}
// Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload.
if err := recv(a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decompressorV1, false); err == io.EOF {
if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decompressorV1, false); err == io.EOF {
return a.transportStream.Status().Err() // non-server streaming Recv returns nil on success
} else if err != nil {
return toRPCErr(err)
}
return status.Errorf(codes.Internal, "cardinality violation: expected <EOF> for non server-streaming RPCs, but received another message")
return status.Error(codes.Internal, "cardinality violation: expected <EOF> for non server-streaming RPCs, but received another message")
}
func (a *csAttempt) finish(err error) {
@@ -1204,15 +1246,14 @@ func (a *csAttempt) finish(err error) {
ServerLoad: balancerload.Parse(tr),
})
}
for _, sh := range a.statsHandlers {
end := &stats.End{
if a.statsHandler != nil {
a.statsHandler.HandleRPC(a.ctx, &stats.End{
Client: true,
BeginTime: a.beginTime,
EndTime: time.Now(),
Trailer: tr,
Error: err,
}
sh.HandleRPC(a.ctx, end)
})
}
if a.trInfo != nil && a.trInfo.tr != nil {
if err == nil {
@@ -1309,16 +1350,18 @@ func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method strin
codec: c.codec,
sendCompressorV0: cp,
sendCompressorV1: comp,
decompressorV0: ac.cc.dopts.dc,
transport: t,
}
s, err := as.transport.NewStream(as.ctx, as.callHdr)
// nil stats handler: internal streams like health and ORCA do not support telemetry.
s, err := as.transport.NewStream(as.ctx, as.callHdr, nil)
if err != nil {
err = toRPCErr(err)
return nil, err
}
as.transportStream = s
as.parser = &parser{r: s, bufferPool: ac.dopts.copts.BufferPool}
as.parser = parser{r: s, bufferPool: ac.dopts.copts.BufferPool}
ac.incrCallsStarted()
if desc != unaryStreamDesc {
// Listen on stream context to cleanup when the stream context is
@@ -1353,6 +1396,7 @@ type addrConnStream struct {
transport transport.ClientTransport
ctx context.Context
sentLast bool
receivedFirstMsg bool
desc *StreamDesc
codec baseCodec
sendCompressorV0 Compressor
@@ -1360,7 +1404,7 @@ type addrConnStream struct {
decompressorSet bool
decompressorV0 Decompressor
decompressorV1 encoding.Compressor
parser *parser
parser parser
// mu guards finished and is held for the entire finish method.
mu sync.Mutex
@@ -1466,6 +1510,10 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
as.decompressorV0 = nil
as.decompressorV1 = encoding.GetCompressor(ct)
}
// Validate that the compression method is acceptable for this call.
if !acceptedCompressorAllows(as.callInfo.acceptedResponseCompressors, ct) {
return status.Errorf(codes.Internal, "grpc: peer compressed the response with %q which is not allowed by AcceptCompressors", ct)
}
} else {
// No compression is used; disable our decompressor.
as.decompressorV0 = nil
@@ -1473,15 +1521,20 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
// Only initialize this state once per stream.
as.decompressorSet = true
}
if err := recv(as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err != nil {
if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err != nil {
if err == io.EOF {
if statusErr := as.transportStream.Status().Err(); statusErr != nil {
return statusErr
}
// Received no msg and status OK for non-server streaming rpcs.
if !as.desc.ServerStreams && !as.receivedFirstMsg {
return status.Error(codes.Internal, "cardinality violation: received no response message from non-server-streaming RPC")
}
return io.EOF // indicates successful end of stream.
}
return toRPCErr(err)
}
as.receivedFirstMsg = true
if as.desc.ServerStreams {
// Subsequent messages should be received by subsequent RecvMsg calls.
@@ -1490,12 +1543,12 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
// Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload.
if err := recv(as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err == io.EOF {
if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err == io.EOF {
return as.transportStream.Status().Err() // non-server streaming Recv returns nil on success
} else if err != nil {
return toRPCErr(err)
}
return status.Errorf(codes.Internal, "cardinality violation: expected <EOF> for non server-streaming RPCs, but received another message")
return status.Error(codes.Internal, "cardinality violation: expected <EOF> for non server-streaming RPCs, but received another message")
}
func (as *addrConnStream) finish(err error) {
@@ -1578,8 +1631,9 @@ type ServerStream interface {
type serverStream struct {
ctx context.Context
s *transport.ServerStream
p *parser
p parser
codec baseCodec
desc *StreamDesc
compressorV0 Compressor
compressorV1 encoding.Compressor
@@ -1588,11 +1642,13 @@ type serverStream struct {
sendCompressorName string
recvFirstMsg bool // set after the first message is received
maxReceiveMessageSize int
maxSendMessageSize int
trInfo *traceInfo
statsHandler []stats.Handler
statsHandler stats.Handler
binlogs []binarylog.MethodLogger
// serverHeaderBinlogged indicates whether server header has been logged. It
@@ -1728,10 +1784,8 @@ func (ss *serverStream) SendMsg(m any) (err error) {
binlog.Log(ss.ctx, sm)
}
}
if len(ss.statsHandler) != 0 {
for _, sh := range ss.statsHandler {
sh.HandleRPC(ss.s.Context(), outPayload(false, m, dataLen, payloadLen, time.Now()))
}
if ss.statsHandler != nil {
ss.statsHandler.HandleRPC(ss.s.Context(), outPayload(false, m, dataLen, payloadLen, time.Now()))
}
return nil
}
@@ -1762,11 +1816,11 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
}
}()
var payInfo *payloadInfo
if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 {
if ss.statsHandler != nil || len(ss.binlogs) != 0 {
payInfo = &payloadInfo{}
defer payInfo.free()
}
if err := recv(ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, payInfo, ss.decompressorV1, true); err != nil {
if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, payInfo, ss.decompressorV1, true); err != nil {
if err == io.EOF {
if len(ss.binlogs) != 0 {
chc := &binarylog.ClientHalfClose{}
@@ -1774,6 +1828,10 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
binlog.Log(ss.ctx, chc)
}
}
// Received no request msg for non-client streaming rpcs.
if !ss.desc.ClientStreams && !ss.recvFirstMsg {
return status.Error(codes.Internal, "cardinality violation: received no request message from non-client-streaming RPC")
}
return err
}
if err == io.ErrUnexpectedEOF {
@@ -1781,16 +1839,15 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
}
return toRPCErr(err)
}
if len(ss.statsHandler) != 0 {
for _, sh := range ss.statsHandler {
sh.HandleRPC(ss.s.Context(), &stats.InPayload{
RecvTime: time.Now(),
Payload: m,
Length: payInfo.uncompressedBytes.Len(),
WireLength: payInfo.compressedLength + headerLen,
CompressedLength: payInfo.compressedLength,
})
}
ss.recvFirstMsg = true
if ss.statsHandler != nil {
ss.statsHandler.HandleRPC(ss.s.Context(), &stats.InPayload{
RecvTime: time.Now(),
Payload: m,
Length: payInfo.uncompressedBytes.Len(),
WireLength: payInfo.compressedLength + headerLen,
CompressedLength: payInfo.compressedLength,
})
}
if len(ss.binlogs) != 0 {
cm := &binarylog.ClientMessage{
@@ -1800,7 +1857,19 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
binlog.Log(ss.ctx, cm)
}
}
return nil
if ss.desc.ClientStreams {
// Subsequent messages should be received by subsequent RecvMsg calls.
return nil
}
// Special handling for non-client-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload.
if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true); err == io.EOF {
return nil
} else if err != nil {
return err
}
return status.Error(codes.Internal, "cardinality violation: received multiple request messages for non-client-streaming RPC")
}
// MethodFromServerStream returns the method string for the input stream.

View File

@@ -19,4 +19,4 @@
package grpc
// Version is the current grpc version.
const Version = "1.74.2"
const Version = "1.79.3"

View File

@@ -72,9 +72,10 @@ type (
EditionFeatures EditionFeatures
}
FileL2 struct {
Options func() protoreflect.ProtoMessage
Imports FileImports
Locations SourceLocations
Options func() protoreflect.ProtoMessage
Imports FileImports
OptionImports func() protoreflect.FileImports
Locations SourceLocations
}
// EditionFeatures is a frequently-instantiated struct, so please take care
@@ -126,12 +127,9 @@ func (fd *File) ParentFile() protoreflect.FileDescriptor { return fd }
func (fd *File) Parent() protoreflect.Descriptor { return nil }
func (fd *File) Index() int { return 0 }
func (fd *File) Syntax() protoreflect.Syntax { return fd.L1.Syntax }
// Not exported and just used to reconstruct the original FileDescriptor proto
func (fd *File) Edition() int32 { return int32(fd.L1.Edition) }
func (fd *File) Name() protoreflect.Name { return fd.L1.Package.Name() }
func (fd *File) FullName() protoreflect.FullName { return fd.L1.Package }
func (fd *File) IsPlaceholder() bool { return false }
func (fd *File) Name() protoreflect.Name { return fd.L1.Package.Name() }
func (fd *File) FullName() protoreflect.FullName { return fd.L1.Package }
func (fd *File) IsPlaceholder() bool { return false }
func (fd *File) Options() protoreflect.ProtoMessage {
if f := fd.lazyInit().Options; f != nil {
return f()
@@ -150,6 +148,16 @@ func (fd *File) Format(s fmt.State, r rune) { descfmt.FormatD
func (fd *File) ProtoType(protoreflect.FileDescriptor) {}
func (fd *File) ProtoInternal(pragma.DoNotImplement) {}
// The next two are not part of the FileDescriptor interface. They are just used to reconstruct
// the original FileDescriptor proto.
func (fd *File) Edition() int32 { return int32(fd.L1.Edition) }
func (fd *File) OptionImports() protoreflect.FileImports {
if f := fd.lazyInit().OptionImports; f != nil {
return f()
}
return emptyFiles
}
func (fd *File) lazyInit() *FileL2 {
if atomic.LoadUint32(&fd.once) == 0 {
fd.lazyInitOnce()
@@ -182,9 +190,9 @@ type (
L2 *EnumL2 // protected by fileDesc.once
}
EnumL1 struct {
eagerValues bool // controls whether EnumL2.Values is already populated
EditionFeatures EditionFeatures
Visibility int32
eagerValues bool // controls whether EnumL2.Values is already populated
}
EnumL2 struct {
Options func() protoreflect.ProtoMessage
@@ -219,6 +227,11 @@ func (ed *Enum) ReservedNames() protoreflect.Names { return &ed.lazyInit()
func (ed *Enum) ReservedRanges() protoreflect.EnumRanges { return &ed.lazyInit().ReservedRanges }
func (ed *Enum) Format(s fmt.State, r rune) { descfmt.FormatDesc(s, r, ed) }
func (ed *Enum) ProtoType(protoreflect.EnumDescriptor) {}
// This is not part of the EnumDescriptor interface. It is just used to reconstruct
// the original FileDescriptor proto.
func (ed *Enum) Visibility() int32 { return ed.L1.Visibility }
func (ed *Enum) lazyInit() *EnumL2 {
ed.L0.ParentFile.lazyInit() // implicitly initializes L2
return ed.L2
@@ -244,13 +257,13 @@ type (
L2 *MessageL2 // protected by fileDesc.once
}
MessageL1 struct {
Enums Enums
Messages Messages
Extensions Extensions
IsMapEntry bool // promoted from google.protobuf.MessageOptions
IsMessageSet bool // promoted from google.protobuf.MessageOptions
Enums Enums
Messages Messages
Extensions Extensions
EditionFeatures EditionFeatures
Visibility int32
IsMapEntry bool // promoted from google.protobuf.MessageOptions
IsMessageSet bool // promoted from google.protobuf.MessageOptions
}
MessageL2 struct {
Options func() protoreflect.ProtoMessage
@@ -319,6 +332,11 @@ func (md *Message) Messages() protoreflect.MessageDescriptors { return &md.L
func (md *Message) Extensions() protoreflect.ExtensionDescriptors { return &md.L1.Extensions }
func (md *Message) ProtoType(protoreflect.MessageDescriptor) {}
func (md *Message) Format(s fmt.State, r rune) { descfmt.FormatDesc(s, r, md) }
// This is not part of the MessageDescriptor interface. It is just used to reconstruct
// the original FileDescriptor proto.
func (md *Message) Visibility() int32 { return md.L1.Visibility }
func (md *Message) lazyInit() *MessageL2 {
md.L0.ParentFile.lazyInit() // implicitly initializes L2
return md.L2

View File

@@ -284,6 +284,13 @@ func (ed *Enum) unmarshalSeed(b []byte, sb *strs.Builder, pf *File, pd protorefl
case genid.EnumDescriptorProto_Value_field_number:
numValues++
}
case protowire.VarintType:
v, m := protowire.ConsumeVarint(b)
b = b[m:]
switch num {
case genid.EnumDescriptorProto_Visibility_field_number:
ed.L1.Visibility = int32(v)
}
default:
m := protowire.ConsumeFieldValue(num, typ, b)
b = b[m:]
@@ -365,6 +372,13 @@ func (md *Message) unmarshalSeed(b []byte, sb *strs.Builder, pf *File, pd protor
md.unmarshalSeedOptions(v)
}
prevField = num
case protowire.VarintType:
v, m := protowire.ConsumeVarint(b)
b = b[m:]
switch num {
case genid.DescriptorProto_Visibility_field_number:
md.L1.Visibility = int32(v)
}
default:
m := protowire.ConsumeFieldValue(num, typ, b)
b = b[m:]

View File

@@ -134,6 +134,7 @@ func (fd *File) unmarshalFull(b []byte) {
var enumIdx, messageIdx, extensionIdx, serviceIdx int
var rawOptions []byte
var optionImports []string
fd.L2 = new(FileL2)
for len(b) > 0 {
num, typ, n := protowire.ConsumeTag(b)
@@ -157,6 +158,8 @@ func (fd *File) unmarshalFull(b []byte) {
imp = PlaceholderFile(path)
}
fd.L2.Imports = append(fd.L2.Imports, protoreflect.FileImport{FileDescriptor: imp})
case genid.FileDescriptorProto_OptionDependency_field_number:
optionImports = append(optionImports, sb.MakeString(v))
case genid.FileDescriptorProto_EnumType_field_number:
fd.L1.Enums.List[enumIdx].unmarshalFull(v, sb)
enumIdx++
@@ -178,6 +181,23 @@ func (fd *File) unmarshalFull(b []byte) {
}
}
fd.L2.Options = fd.builder.optionsUnmarshaler(&descopts.File, rawOptions)
if len(optionImports) > 0 {
var imps FileImports
var once sync.Once
fd.L2.OptionImports = func() protoreflect.FileImports {
once.Do(func() {
imps = make(FileImports, len(optionImports))
for i, path := range optionImports {
imp, _ := fd.builder.FileRegistry.FindFileByPath(path)
if imp == nil {
imp = PlaceholderFile(path)
}
imps[i] = protoreflect.FileImport{FileDescriptor: imp}
}
})
return &imps
}
}
}
func (ed *Enum) unmarshalFull(b []byte, sb *strs.Builder) {

View File

@@ -13,8 +13,10 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
)
var defaultsCache = make(map[Edition]EditionFeatures)
var defaultsKeys = []Edition{}
var (
defaultsCache = make(map[Edition]EditionFeatures)
defaultsKeys = []Edition{}
)
func init() {
unmarshalEditionDefaults(editiondefaults.Defaults)
@@ -41,7 +43,7 @@ func unmarshalGoFeature(b []byte, parent EditionFeatures) EditionFeatures {
b = b[m:]
parent.StripEnumPrefix = int(v)
default:
panic(fmt.Sprintf("unkown field number %d while unmarshalling GoFeatures", num))
panic(fmt.Sprintf("unknown field number %d while unmarshalling GoFeatures", num))
}
}
return parent
@@ -76,7 +78,7 @@ func unmarshalFeatureSet(b []byte, parent EditionFeatures) EditionFeatures {
// DefaultSymbolVisibility is enforced in protoc, runtimes should not
// inspect this value.
default:
panic(fmt.Sprintf("unkown field number %d while unmarshalling FeatureSet", num))
panic(fmt.Sprintf("unknown field number %d while unmarshalling FeatureSet", num))
}
case protowire.BytesType:
v, m := protowire.ConsumeBytes(b)
@@ -150,7 +152,7 @@ func unmarshalEditionDefaults(b []byte) {
_, m := protowire.ConsumeVarint(b)
b = b[m:]
default:
panic(fmt.Sprintf("unkown field number %d while unmarshalling EditionDefault", num))
panic(fmt.Sprintf("unknown field number %d while unmarshalling EditionDefault", num))
}
}
}

View File

@@ -27,6 +27,7 @@ const (
Api_SourceContext_field_name protoreflect.Name = "source_context"
Api_Mixins_field_name protoreflect.Name = "mixins"
Api_Syntax_field_name protoreflect.Name = "syntax"
Api_Edition_field_name protoreflect.Name = "edition"
Api_Name_field_fullname protoreflect.FullName = "google.protobuf.Api.name"
Api_Methods_field_fullname protoreflect.FullName = "google.protobuf.Api.methods"
@@ -35,6 +36,7 @@ const (
Api_SourceContext_field_fullname protoreflect.FullName = "google.protobuf.Api.source_context"
Api_Mixins_field_fullname protoreflect.FullName = "google.protobuf.Api.mixins"
Api_Syntax_field_fullname protoreflect.FullName = "google.protobuf.Api.syntax"
Api_Edition_field_fullname protoreflect.FullName = "google.protobuf.Api.edition"
)
// Field numbers for google.protobuf.Api.
@@ -46,6 +48,7 @@ const (
Api_SourceContext_field_number protoreflect.FieldNumber = 5
Api_Mixins_field_number protoreflect.FieldNumber = 6
Api_Syntax_field_number protoreflect.FieldNumber = 7
Api_Edition_field_number protoreflect.FieldNumber = 8
)
// Names for google.protobuf.Method.
@@ -63,6 +66,7 @@ const (
Method_ResponseStreaming_field_name protoreflect.Name = "response_streaming"
Method_Options_field_name protoreflect.Name = "options"
Method_Syntax_field_name protoreflect.Name = "syntax"
Method_Edition_field_name protoreflect.Name = "edition"
Method_Name_field_fullname protoreflect.FullName = "google.protobuf.Method.name"
Method_RequestTypeUrl_field_fullname protoreflect.FullName = "google.protobuf.Method.request_type_url"
@@ -71,6 +75,7 @@ const (
Method_ResponseStreaming_field_fullname protoreflect.FullName = "google.protobuf.Method.response_streaming"
Method_Options_field_fullname protoreflect.FullName = "google.protobuf.Method.options"
Method_Syntax_field_fullname protoreflect.FullName = "google.protobuf.Method.syntax"
Method_Edition_field_fullname protoreflect.FullName = "google.protobuf.Method.edition"
)
// Field numbers for google.protobuf.Method.
@@ -82,6 +87,7 @@ const (
Method_ResponseStreaming_field_number protoreflect.FieldNumber = 5
Method_Options_field_number protoreflect.FieldNumber = 6
Method_Syntax_field_number protoreflect.FieldNumber = 7
Method_Edition_field_number protoreflect.FieldNumber = 8
)
// Names for google.protobuf.Mixin.

View File

@@ -52,7 +52,7 @@ import (
const (
Major = 1
Minor = 36
Patch = 7
Patch = 10
PreRelease = ""
)