Skip to content

Commit 033f95e

Browse files
committed
chore: implement device auth flow for fake idp
1 parent 1183cc4 commit 033f95e

File tree

2 files changed

+188
-28
lines changed

2 files changed

+188
-28
lines changed

coderd/coderdtest/oidctest/idp.go

Lines changed: 180 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@ import (
1010
"errors"
1111
"fmt"
1212
"io"
13+
"math/rand"
1314
"mime"
1415
"net"
1516
"net/http"
1617
"net/http/cookiejar"
1718
"net/http/httptest"
1819
"net/url"
20+
"strconv"
1921
"strings"
2022
"testing"
2123
"time"
@@ -47,6 +49,13 @@ type token struct {
4749
exp time.Time
4850
}
4951

52+
type deviceFlow struct {
53+
// userInput is the expected input to authenticate the device flow.
54+
userInput string
55+
exp time.Time
56+
granted bool
57+
}
58+
5059
// FakeIDP is a functional OIDC provider.
5160
// It only supports 1 OIDC client.
5261
type FakeIDP struct {
@@ -79,6 +88,9 @@ type FakeIDP struct {
7988
refreshTokens *syncmap.Map[string, string]
8089
stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
8190
refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
91+
// Device flow
92+
deviceCode *syncmap.Map[string, deviceFlow]
93+
deviceCodeInput *syncmap.Map[string, externalauth.ExchangeDeviceCodeResponse]
8294

8395
// hooks
8496
// hookValidRedirectURL can be used to reject a redirect url from the
@@ -229,6 +241,7 @@ const (
229241
keysPath = "/oauth2/keys"
230242
userInfoPath = "/oauth2/userinfo"
231243
deviceAuth = "/login/device/code"
244+
deviceVerify = "/login/device"
232245
)
233246

234247
func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
@@ -249,6 +262,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
249262
refreshTokensUsed: syncmap.New[string, bool](),
250263
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
251264
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
265+
deviceCode: syncmap.New[string, deviceFlow](),
252266
hookOnRefresh: func(_ string) error { return nil },
253267
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
254268
hookValidRedirectURL: func(redirectURL string) error { return nil },
@@ -291,11 +305,12 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
291305
// ProviderJSON is the JSON representation of the OpenID Connect provider
292306
// These are all the urls that the IDP will respond to.
293307
f.provider = ProviderJSON{
294-
Issuer: issuer,
295-
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
296-
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
297-
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
298-
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
308+
Issuer: issuer,
309+
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
310+
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
311+
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
312+
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
313+
DeviceCodeURL: u.ResolveReference(&url.URL{Path: deviceAuth}).String(),
299314
Algorithms: []string{
300315
"RS256",
301316
},
@@ -539,12 +554,13 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
539554

540555
// ProviderJSON is the .well-known/configuration JSON
541556
type ProviderJSON struct {
542-
Issuer string `json:"issuer"`
543-
AuthURL string `json:"authorization_endpoint"`
544-
TokenURL string `json:"token_endpoint"`
545-
JWKSURL string `json:"jwks_uri"`
546-
UserInfoURL string `json:"userinfo_endpoint"`
547-
Algorithms []string `json:"id_token_signing_alg_values_supported"`
557+
Issuer string `json:"issuer"`
558+
AuthURL string `json:"authorization_endpoint"`
559+
TokenURL string `json:"token_endpoint"`
560+
JWKSURL string `json:"jwks_uri"`
561+
UserInfoURL string `json:"userinfo_endpoint"`
562+
DeviceCodeURL string `json:"device_authorization_endpoint"`
563+
Algorithms []string `json:"id_token_signing_alg_values_supported"`
548564
// This is custom
549565
ExternalAuthURL string `json:"external_auth_url"`
550566
}
@@ -712,8 +728,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
712728
}))
713729

714730
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
715-
values, err := f.authenticateOIDCClientRequest(t, r)
731+
var values url.Values
732+
var err error
733+
if r.URL.Query().Get("grant_type") == "urn:ietf:params:oauth:grant-type:device_code" {
734+
values = r.URL.Query()
735+
} else {
736+
values, err = f.authenticateOIDCClientRequest(t, r)
737+
}
716738
f.logger.Info(r.Context(), "http idp call token",
739+
slog.F("url", r.URL.String()),
717740
slog.F("valid", err == nil),
718741
slog.F("grant_type", values.Get("grant_type")),
719742
slog.F("values", values.Encode()),
@@ -789,6 +812,35 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
789812
f.refreshTokens.Delete(refreshToken)
790813
case "urn:ietf:params:oauth:grant-type:device_code":
791814
// Device flow
815+
var resp externalauth.ExchangeDeviceCodeResponse
816+
deviceCode := values.Get("device_code")
817+
if deviceCode == "" {
818+
resp.Error = "invalid_request"
819+
resp.ErrorDescription = "missing device_code"
820+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
821+
return
822+
}
823+
824+
deviceFlow, ok := f.deviceCode.Load(deviceCode)
825+
if !ok {
826+
resp.Error = "invalid_request"
827+
resp.ErrorDescription = "device_code provided not found"
828+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
829+
return
830+
}
831+
832+
if !deviceFlow.granted {
833+
// Status code ok with the error as pending.
834+
resp.Error = "authorization_pending"
835+
resp.ErrorDescription = ""
836+
httpapi.Write(r.Context(), rw, http.StatusOK, resp)
837+
return
838+
}
839+
840+
// Would be nice to get an actual email here.
841+
claims = jwt.MapClaims{
842+
"email": "unknown-dev-auth",
843+
}
792844
default:
793845
t.Errorf("unexpected grant_type %q", values.Get("grant_type"))
794846
http.Error(rw, "invalid grant_type", http.StatusBadRequest)
@@ -812,8 +864,19 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
812864
// Store the claims for the next refresh
813865
f.refreshIDTokenClaims.Store(refreshToken, claims)
814866

815-
rw.Header().Set("Content-Type", "application/json")
816-
_ = json.NewEncoder(rw).Encode(token)
867+
if mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept")); mediaType == "application/json" {
868+
rw.Header().Set("Content-Type", "application/json")
869+
_ = json.NewEncoder(rw).Encode(token)
870+
return
871+
}
872+
873+
// Default to form encode. Just to make sure our code sets the right headers.
874+
rw.Header().Set("Content-Type", "application/x-www-form-urlencoded")
875+
vals := url.Values{}
876+
for k, v := range token {
877+
vals.Set(k, fmt.Sprintf("%v", v))
878+
}
879+
_, _ = rw.Write([]byte(vals.Encode()))
817880
}))
818881

819882
validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) {
@@ -891,10 +954,68 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
891954
_ = json.NewEncoder(rw).Encode(set)
892955
}))
893956

957+
mux.Handle(deviceVerify, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
958+
f.logger.Info(r.Context(), "http call device verify")
959+
960+
inputParam := "user_input"
961+
userInput := r.URL.Query().Get(inputParam)
962+
if userInput == "" {
963+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
964+
Message: "Invalid user input",
965+
Detail: fmt.Sprintf("Hit this url again with ?%s=<user_code>", inputParam),
966+
})
967+
return
968+
}
969+
970+
deviceCode := r.URL.Query().Get("device_code")
971+
if deviceCode == "" {
972+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
973+
Message: "Invalid device code",
974+
Detail: "Hit this url again with ?device_code=<device_code>",
975+
})
976+
return
977+
}
978+
979+
flow, ok := f.deviceCode.Load(deviceCode)
980+
if !ok {
981+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
982+
Message: "Invalid device code",
983+
Detail: "Device code not found.",
984+
})
985+
return
986+
}
987+
988+
if time.Now().After(flow.exp) {
989+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
990+
Message: "Invalid device code",
991+
Detail: "Device code expired.",
992+
})
993+
return
994+
}
995+
996+
if strings.TrimSpace(flow.userInput) != strings.TrimSpace(userInput) {
997+
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
998+
Message: "Invalid device code",
999+
Detail: "user code does not match",
1000+
})
1001+
return
1002+
}
1003+
1004+
f.deviceCode.Store(deviceCode, deviceFlow{
1005+
userInput: flow.userInput,
1006+
exp: flow.exp,
1007+
granted: true,
1008+
})
1009+
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
1010+
Message: "Device authenticated!",
1011+
})
1012+
}))
1013+
8941014
mux.Handle(deviceAuth, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
1015+
f.logger.Info(r.Context(), "http call device auth")
1016+
8951017
p := httpapi.NewQueryParamParser()
8961018
p.Required("client_id")
897-
p.Required("scopes")
8981019
clientID := p.String(r.URL.Query(), "", "client_id")
8991020
_ = p.String(r.URL.Query(), "", "scopes")
9001021
if len(p.Errors) > 0 {
@@ -912,24 +1033,42 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
9121033
return
9131034
}
9141035

1036+
deviceCode := uuid.NewString()
1037+
lifetime := time.Second * 900
1038+
flow := deviceFlow{
1039+
userInput: fmt.Sprintf("%d", rand.Intn(9999999)+1e8),
1040+
}
1041+
f.deviceCode.Store(deviceCode, deviceFlow{
1042+
userInput: flow.userInput,
1043+
exp: time.Now().Add(lifetime),
1044+
})
1045+
1046+
verifyURL := f.issuerURL.ResolveReference(&url.URL{
1047+
Path: deviceVerify,
1048+
RawQuery: url.Values{
1049+
"device_code": {deviceCode},
1050+
"user_input": {flow.userInput},
1051+
}.Encode(),
1052+
}).String()
1053+
9151054
if mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept")); mediaType == "application/json" {
9161055
httpapi.Write(r.Context(), rw, http.StatusOK, map[string]any{
917-
"device_code": uuid.NewString(),
918-
"user_code": "1234",
919-
"verification_uri": "",
920-
"expires_in": 900,
921-
"interval": 0,
1056+
"device_code": deviceCode,
1057+
"user_code": flow.userInput,
1058+
"verification_uri": verifyURL,
1059+
"expires_in": int(lifetime.Seconds()),
1060+
"interval": 3,
9221061
})
9231062
return
9241063
}
9251064

9261065
// By default, GitHub form encodes these.
9271066
_, _ = fmt.Fprint(rw, url.Values{
928-
"device_code": {uuid.NewString()},
929-
"user_code": {"1234"},
930-
"verification_uri": {""},
931-
"expires_in": {"900"},
932-
"interval": {"0"},
1067+
"device_code": {deviceCode},
1068+
"user_code": {flow.userInput},
1069+
"verification_uri": {verifyURL},
1070+
"expires_in": {strconv.Itoa(int(lifetime.Seconds()))},
1071+
"interval": {"3"},
9331072
})
9341073
}))
9351074

@@ -1034,6 +1173,8 @@ type ExternalAuthConfigOptions struct {
10341173
// completely customize the response. It captures all routes under the /external-auth-validate/*
10351174
// so the caller can do whatever they want and even add routes.
10361175
routes map[string]func(email string, rw http.ResponseWriter, r *http.Request)
1176+
1177+
UseDeviceAuth bool
10371178
}
10381179

10391180
func (o *ExternalAuthConfigOptions) AddRoute(route string, handle func(email string, rw http.ResponseWriter, r *http.Request)) *ExternalAuthConfigOptions {
@@ -1080,17 +1221,30 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
10801221
}
10811222
}
10821223
instrumentF := promoauth.NewFactory(prometheus.NewRegistry())
1224+
oauthCfg := instrumentF.New(f.clientID, f.OIDCConfig(t, nil))
10831225
cfg := &externalauth.Config{
10841226
DisplayName: id,
1085-
InstrumentedOAuth2Config: instrumentF.New(f.clientID, f.OIDCConfig(t, nil)),
1227+
InstrumentedOAuth2Config: oauthCfg,
10861228
ID: id,
10871229
// No defaults for these fields by omitting the type
10881230
Type: "",
10891231
DisplayIcon: f.WellknownConfig().UserInfoURL,
10901232
// Omit the /user for the validate so we can easily append to it when modifying
10911233
// the cfg for advanced tests.
10921234
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
1235+
DeviceAuth: &externalauth.DeviceAuth{
1236+
Config: oauthCfg,
1237+
ClientID: f.clientID,
1238+
TokenURL: f.provider.TokenURL,
1239+
Scopes: []string{},
1240+
CodeURL: f.provider.DeviceCodeURL,
1241+
},
10931242
}
1243+
1244+
if !custom.UseDeviceAuth {
1245+
cfg.DeviceAuth = nil
1246+
}
1247+
10941248
for _, opt := range opts {
10951249
opt(cfg)
10961250
}

scripts/testidp/main.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ var (
2323
expiry = flag.Duration("expiry", time.Minute*5, "Token expiry")
2424
clientID = flag.String("client-id", "static-client-id", "Client ID, set empty to be random")
2525
clientSecret = flag.String("client-sec", "static-client-secret", "Client Secret, set empty to be random")
26+
deviceFlow = flag.Bool("device-flow", false, "Enable device flow")
2627
// By default, no regex means it will never match anything. So at least default to matching something.
2728
extRegex = flag.String("ext-regex", `^(https?://)?example\.com(/.*)?$`, "External auth regex")
2829
)
@@ -66,14 +67,18 @@ func RunIDP() func(t *testing.T) {
6667
id, sec := idp.AppCredentials()
6768
prov := idp.WellknownConfig()
6869
const appID = "fake"
69-
coderCfg := idp.ExternalAuthConfig(t, appID, nil)
70+
coderCfg := idp.ExternalAuthConfig(t, appID, &oidctest.ExternalAuthConfigOptions{
71+
UseDeviceAuth: *deviceFlow,
72+
})
7073

7174
log.Println("IDP Issuer URL", idp.IssuerURL())
7275
log.Println("Coderd Flags")
76+
7377
deviceCodeURL := ""
7478
if coderCfg.DeviceAuth != nil {
7579
deviceCodeURL = coderCfg.DeviceAuth.CodeURL
7680
}
81+
7782
cfg := withClientSecret{
7883
ClientSecret: sec,
7984
ExternalAuthConfig: codersdk.ExternalAuthConfig{
@@ -89,13 +94,14 @@ func RunIDP() func(t *testing.T) {
8994
NoRefresh: false,
9095
Scopes: []string{"openid", "email", "profile"},
9196
ExtraTokenKeys: coderCfg.ExtraTokenKeys,
92-
DeviceFlow: coderCfg.DeviceAuth != nil,
97+
DeviceFlow: *deviceFlow,
9398
DeviceCodeURL: deviceCodeURL,
9499
Regex: *extRegex,
95100
DisplayName: coderCfg.DisplayName,
96101
DisplayIcon: coderCfg.DisplayIcon,
97102
},
98103
}
104+
99105
data, err := json.Marshal([]withClientSecret{cfg})
100106
require.NoError(t, err)
101107
log.Printf(`--external-auth-providers='%s'`, string(data))

0 commit comments

Comments
 (0)