Skip to content

chore: rework RPC version negotiation #15687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions vpn/speaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"google.golang.org/protobuf/proto"

"cdr.dev/slog"
"github.com/coder/coder/v2/apiversion"
)

type SpeakerRole string
Expand Down Expand Up @@ -258,7 +257,7 @@ func handshake(
// read and write simultaneously to avoid deadlocking if the conn is not buffered
errCh := make(chan error, 2)
go func() {
ours := headerString(CurrentVersion, me)
ours := headerString(me, CurrentSupportedVersions)
_, err := conn.Write([]byte(ours))
logger.Debug(ctx, "wrote out header")
if err != nil {
Expand Down Expand Up @@ -316,34 +315,43 @@ func handshake(
}
}
logger.Debug(ctx, "handshake read/write complete", slog.F("their_header", theirHeader))
err := validateHeader(theirHeader, them)
gotVersion, err := validateHeader(theirHeader, them, CurrentSupportedVersions)
if err != nil {
return xerrors.Errorf("validate header (%s): %w", theirHeader, err)
}
logger.Debug(ctx, "handshake validated", slog.F("common_version", gotVersion))
// TODO: actually use the common version to perform different behavior once
// we have multiple versions
return nil
}

const headerPreamble = "codervpn"

func headerString(version *apiversion.APIVersion, role SpeakerRole) string {
return fmt.Sprintf("%s %s %s\n", headerPreamble, version.String(), role)
func headerString(role SpeakerRole, versions RPCVersionList) string {
return fmt.Sprintf("%s %s %s\n", headerPreamble, role, versions.String())
}

func validateHeader(header string, expectedRole SpeakerRole) error {
func validateHeader(header string, expectedRole SpeakerRole, supportedVersions RPCVersionList) (RPCVersion, error) {
parts := strings.Split(header, " ")
if len(parts) != 3 {
return xerrors.New("wrong number of parts")
return RPCVersion{}, xerrors.New("wrong number of parts")
}
if parts[0] != headerPreamble {
return xerrors.New("invalid preamble")
return RPCVersion{}, xerrors.New("invalid preamble")
}
if err := CurrentVersion.Validate(parts[1]); err != nil {
return xerrors.Errorf("version: %w", err)
if parts[1] != string(expectedRole) {
return RPCVersion{}, xerrors.New("unexpected role")
}
if parts[2] != string(expectedRole) {
return xerrors.New("unexpected role")
otherVersions, err := ParseRPCVersionList(parts[2])
if err != nil {
return RPCVersion{}, xerrors.Errorf("parse version list %q: %w", parts[2], err)
}
return nil
compatibleVersion, ok := supportedVersions.IsCompatibleWith(otherVersions)
if !ok {
return RPCVersion{},
xerrors.Errorf("current supported versions %q is not compatible with peer versions %q", supportedVersions.String(), otherVersions.String())
}
return compatibleVersion, nil
}

type request[S rpcMessage, R rpcMessage] struct {
Expand Down
18 changes: 9 additions & 9 deletions vpn/speaker_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ func TestSpeaker_RawPeer(t *testing.T) {
errCh <- err
}()

expectedHandshake := "codervpn 1.0 tunnel\n"
expectedHandshake := "codervpn tunnel 1.0\n"

b := make([]byte, 256)
n, err := mp.Read(b)
require.NoError(t, err)
require.Equal(t, expectedHandshake, string(b[:n]))

_, err = mp.Write([]byte("codervpn 1.0 manager\n"))
_, err = mp.Write([]byte("codervpn manager 1.3,2.1\n"))
require.NoError(t, err)

err = testutil.RequireRecvCtx(ctx, t, errCh)
Expand Down Expand Up @@ -155,7 +155,7 @@ func TestSpeaker_OversizeHandshake(t *testing.T) {
errCh <- err
}()

expectedHandshake := "codervpn 1.0 tunnel\n"
expectedHandshake := "codervpn tunnel 1.0\n"

b := make([]byte, 256)
n, err := mp.Read(b)
Expand All @@ -177,10 +177,10 @@ func TestSpeaker_HandshakeInvalid(t *testing.T) {
for _, tc := range []struct {
name, handshake string
}{
{name: "preamble", handshake: "ssh 1.0 manager\n"},
{name: "preamble", handshake: "ssh manager 1.0\n"},
{name: "2components", handshake: "ssh manager\n"},
{name: "newversion", handshake: "codervpn 1.1 manager\n"},
{name: "oldversion", handshake: "codervpn 0.1 manager\n"},
{name: "newmajors", handshake: "codervpn manager 2.0,3.0\n"},
{name: "0version", handshake: "codervpn 0.1 manager\n"},
{name: "unknown_role", handshake: "codervpn 1.0 supervisor\n"},
{name: "unexpected_role", handshake: "codervpn 1.0 tunnel\n"},
} {
Expand Down Expand Up @@ -208,7 +208,7 @@ func TestSpeaker_HandshakeInvalid(t *testing.T) {
_, err = mp.Write([]byte(tc.handshake))
require.NoError(t, err)

expectedHandshake := "codervpn 1.0 tunnel\n"
expectedHandshake := "codervpn tunnel 1.0\n"
b := make([]byte, 256)
n, err := mp.Read(b)
require.NoError(t, err)
Expand Down Expand Up @@ -246,14 +246,14 @@ func TestSpeaker_CorruptMessage(t *testing.T) {
errCh <- err
}()

expectedHandshake := "codervpn 1.0 tunnel\n"
expectedHandshake := "codervpn tunnel 1.0\n"

b := make([]byte, 256)
n, err := mp.Read(b)
require.NoError(t, err)
require.Equal(t, expectedHandshake, string(b[:n]))

_, err = mp.Write([]byte("codervpn 1.0 manager\n"))
_, err = mp.Write([]byte("codervpn manager 1.0\n"))
require.NoError(t, err)

err = testutil.RequireRecvCtx(ctx, t, errCh)
Expand Down
141 changes: 136 additions & 5 deletions vpn/version.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,141 @@
package vpn

import "github.com/coder/coder/v2/apiversion"
import (
"fmt"
"strconv"
"strings"

const (
CurrentMajor = 1
CurrentMinor = 0
"golang.org/x/xerrors"
)

var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor)
// CurrentSupportedVersions is the list of versions supported by this
// implementation of the VPN RPC protocol.
var CurrentSupportedVersions = RPCVersionList{
Versions: []RPCVersion{
{Major: 1, Minor: 0},
},
}

// RPCVersion represents a single version of the RPC protocol. Any given version
// is expected to be backwards compatible with all previous minor versions on
// the same major version.
//
// e.g. RPCVersion{2, 3} is backwards compatible with RPCVersion{2, 2} but is
// not backwards compatible with RPCVersion{1, 2}.
type RPCVersion struct {
Major uint64 `json:"major"`
Minor uint64 `json:"minor"`
}

// ParseRPCVersion parses a version string in the format "major.minor" into a
// RPCVersion.
func ParseRPCVersion(str string) (RPCVersion, error) {
split := strings.Split(str, ".")
if len(split) != 2 {
return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str)
}
major, err := strconv.ParseUint(split[0], 10, 64)
if err != nil {
return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str)
}
if major == 0 {
return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str)
}
minor, err := strconv.ParseUint(split[1], 10, 64)
if err != nil {
return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str)
}
return RPCVersion{Major: major, Minor: minor}, nil
}

func (v RPCVersion) String() string {
return fmt.Sprintf("%d.%d", v.Major, v.Minor)
}

// IsCompatibleWith returns the lowest version that is compatible with both
// versions. If the versions are not compatible, the second return value will be
// false.
func (v RPCVersion) IsCompatibleWith(other RPCVersion) (RPCVersion, bool) {
if v.Major != other.Major {
return RPCVersion{}, false
}
// The lowest minor version from the two versions should be returned.
if v.Minor < other.Minor {
return v, true
}
return other, true
}

// RPCVersionList represents a list of RPC versions supported by a RPC peer. An
type RPCVersionList struct {
Versions []RPCVersion `json:"versions"`
}

// ParseRPCVersionList parses a version string in the format
// "major.minor,major.minor" into a RPCVersionList.
func ParseRPCVersionList(str string) (RPCVersionList, error) {
split := strings.Split(str, ",")
versions := make([]RPCVersion, len(split))
for i, v := range split {
version, err := ParseRPCVersion(v)
if err != nil {
return RPCVersionList{}, xerrors.Errorf("invalid version list: %s", str)
}
versions[i] = version
}
vl := RPCVersionList{Versions: versions}
err := vl.Validate()
if err != nil {
return RPCVersionList{}, xerrors.Errorf("invalid parsed version list %q: %w", str, err)
}
return vl, nil
}

func (vl RPCVersionList) String() string {
versionStrings := make([]string, len(vl.Versions))
for i, v := range vl.Versions {
versionStrings[i] = v.String()
}
return strings.Join(versionStrings, ",")
}

// Validate returns an error if the version list is not sorted or contains
// duplicate major versions.
func (vl RPCVersionList) Validate() error {
if len(vl.Versions) == 0 {
return xerrors.New("no versions")
}
for i := 0; i < len(vl.Versions); i++ {
if vl.Versions[i].Major == 0 {
return xerrors.Errorf("invalid version: %s", vl.Versions[i].String())
}
if i > 0 && vl.Versions[i-1].Major == vl.Versions[i].Major {
return xerrors.Errorf("duplicate major version: %d", vl.Versions[i].Major)
}
if i > 0 && vl.Versions[i-1].Major > vl.Versions[i].Major {
return xerrors.Errorf("versions are not sorted")
}
}
return nil
}

// IsCompatibleWith returns the lowest version that is compatible with both
// version lists. If the versions are not compatible, the second return value
// will be false.
func (vl RPCVersionList) IsCompatibleWith(other RPCVersionList) (RPCVersion, bool) {
bestVersion := RPCVersion{}
for _, v1 := range vl.Versions {
for _, v2 := range other.Versions {
if v1.Major == v2.Major && v1.Major > bestVersion.Major {
v, ok := v1.IsCompatibleWith(v2)
if ok {
bestVersion = v
}
}
}
}
if bestVersion.Major == 0 {
return bestVersion, false
}
return bestVersion, true
}
Loading
Loading