Skip to content

chore: refactor entitlements to be a safe object to use #14406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"tailscale.com/util/singleflight"

"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/entitlements"
"github.com/coder/quartz"
"github.com/coder/serpent"

Expand Down Expand Up @@ -157,6 +158,9 @@ type Options struct {
TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error
// RefreshEntitlements is used to set correct entitlements after creating first user and generating trial license.
RefreshEntitlements func(ctx context.Context) error
// Entitlements can come from the enterprise caller if enterprise code is
// included.
Entitlements *entitlements.Set
// PostAuthAdditionalHeadersFunc is used to add additional headers to the response
// after a successful authentication.
// This is somewhat janky, but seemingly the only reasonable way to add a header
Expand Down Expand Up @@ -263,6 +267,9 @@ func New(options *Options) *API {
if options == nil {
options = &Options{}
}
if options.Entitlements == nil {
options.Entitlements = entitlements.New()
}
if options.NewTicker == nil {
options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) {
ticker := time.NewTicker(duration)
Expand Down Expand Up @@ -500,6 +507,7 @@ func New(options *Options) *API {
DocsURL: options.DeploymentValues.DocsURL.String(),
AppearanceFetcher: &api.AppearanceFetcher,
BuildInfo: buildInfo,
Entitlements: options.Entitlements,
})
api.SiteHandler.Experiments.Store(&experiments)

Expand Down
109 changes: 109 additions & 0 deletions coderd/entitlements/entitlements.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package entitlements

import (
"encoding/json"
"net/http"
"sync"
"time"

"github.com/coder/coder/v2/codersdk"
)

type Set struct {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I implemented the methods as I saw them used. There might be a way to reduce the number of methods on this struct.

entitlementsMu sync.RWMutex
entitlements codersdk.Entitlements
}

func New() *Set {
return &Set{
// Some defaults for an unlicensed instance.
// These will be updated when coderd is initialized.
entitlements: codersdk.Entitlements{
Features: map[codersdk.FeatureName]codersdk.Feature{},
Warnings: nil,
Errors: nil,
HasLicense: false,
Trial: false,
RequireTelemetry: false,
RefreshedAt: time.Time{},
},
}
}

// AllowRefresh returns whether the entitlements are allowed to be refreshed.
// If it returns false, that means it was recently refreshed and the caller should
// wait the returned duration before trying again.
func (l *Set) AllowRefresh(now time.Time) (bool, time.Duration) {
l.entitlementsMu.RLock()
defer l.entitlementsMu.RUnlock()

diff := now.Sub(l.entitlements.RefreshedAt)
if diff < time.Minute {
return false, time.Minute - diff
}

return true, 0
}

func (l *Set) Feature(name codersdk.FeatureName) (codersdk.Feature, bool) {
l.entitlementsMu.RLock()
defer l.entitlementsMu.RUnlock()

f, ok := l.entitlements.Features[name]
return f, ok
}

func (l *Set) Enabled(feature codersdk.FeatureName) bool {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential follow-up: we could replace this with f, ok := Features(name); ok && f.Enabled?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. Because before we had access to the whole struct, our usage of it seemed a bit arbitrary at times. Sometimes we grab it and check entitled, most times just enabled.

I'm not trying to fix all our usage right now, but it would be good to audit at some times.

l.entitlementsMu.RLock()
defer l.entitlementsMu.RUnlock()

f, ok := l.entitlements.Features[feature]
if !ok {
return false
}
return f.Enabled
}

// AsJSON is used to return this to the api without exposing the entitlements for
// mutation.
func (l *Set) AsJSON() json.RawMessage {
l.entitlementsMu.RLock()
defer l.entitlementsMu.RUnlock()

b, _ := json.Marshal(l.entitlements)
return b
}

func (l *Set) Replace(entitlements codersdk.Entitlements) {
l.entitlementsMu.Lock()
defer l.entitlementsMu.Unlock()

l.entitlements = entitlements
}

func (l *Set) Update(do func(entitlements *codersdk.Entitlements)) {
l.entitlementsMu.Lock()
defer l.entitlementsMu.Unlock()

do(&l.entitlements)
}

func (l *Set) FeatureChanged(featureName codersdk.FeatureName, newFeature codersdk.Feature) (initial, changed, enabled bool) {
l.entitlementsMu.RLock()
defer l.entitlementsMu.RUnlock()

oldFeature := l.entitlements.Features[featureName]
if oldFeature.Enabled != newFeature.Enabled {
return false, true, newFeature.Enabled
}
return false, false, newFeature.Enabled
}

func (l *Set) WriteEntitlementWarningHeaders(header http.Header) {
l.entitlementsMu.RLock()
defer l.entitlementsMu.RUnlock()

for _, warning := range l.entitlements.Warnings {
header.Add(codersdk.EntitlementsWarningHeader, warning)
}
}
63 changes: 63 additions & 0 deletions coderd/entitlements/entitlements_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package entitlements_test

import (
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/coderd/entitlements"
"github.com/coder/coder/v2/codersdk"
)

func TestUpdate(t *testing.T) {
t.Parallel()

set := entitlements.New()
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))

set.Update(func(entitlements *codersdk.Entitlements) {
entitlements.Features[codersdk.FeatureMultipleOrganizations] = codersdk.Feature{
Enabled: true,
Entitlement: codersdk.EntitlementEntitled,
}
})
require.True(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
}

func TestAllowRefresh(t *testing.T) {
t.Parallel()

now := time.Now()
set := entitlements.New()
set.Update(func(entitlements *codersdk.Entitlements) {
entitlements.RefreshedAt = now
})

ok, wait := set.AllowRefresh(now)
require.False(t, ok)
require.InDelta(t, time.Minute.Seconds(), wait.Seconds(), 5)

set.Update(func(entitlements *codersdk.Entitlements) {
entitlements.RefreshedAt = now.Add(time.Minute * -2)
})

ok, wait = set.AllowRefresh(now)
require.True(t, ok)
require.Equal(t, time.Duration(0), wait)
}

func TestReplace(t *testing.T) {
t.Parallel()

set := entitlements.New()
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
set.Replace(codersdk.Entitlements{
Features: map[codersdk.FeatureName]codersdk.Feature{
codersdk.FeatureMultipleOrganizations: {
Enabled: true,
},
},
})
require.True(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
}
6 changes: 6 additions & 0 deletions codersdk/deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ const (
EntitlementNotEntitled Entitlement = "not_entitled"
)

// Entitled returns if the entitlement can be used. So this is true if it
// is entitled or still in it's grace period.
func (e Entitlement) Entitled() bool {
return e == EntitlementEntitled || e == EntitlementGracePeriod
}

// Weight converts the enum types to a numerical value for easier
// comparisons. Easier than sets of if statements.
func (e Entitlement) Weight() int {
Expand Down
72 changes: 30 additions & 42 deletions enterprise/coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/appearance"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/entitlements"
agplportsharing "github.com/coder/coder/v2/coderd/portsharing"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/enterprise/coderd/portsharing"
Expand Down Expand Up @@ -103,19 +104,26 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
}
return nil, xerrors.Errorf("init database encryption: %w", err)
}

entitlementsSet := entitlements.New()
options.Database = cryptDB
api := &API{
ctx: ctx,
cancel: cancelFunc,
Options: options,
ctx: ctx,
cancel: cancelFunc,
Options: options,
entitlements: entitlementsSet,
provisionerDaemonAuth: &provisionerDaemonAuth{
psk: options.ProvisionerDaemonPSK,
authorizer: options.Authorizer,
db: options.Database,
},
licenseMetricsCollector: &license.MetricsCollector{
Entitlements: entitlementsSet,
},
}
// This must happen before coderd initialization!
options.PostAuthAdditionalHeadersFunc = api.writeEntitlementWarningsHeader
options.Options.Entitlements = api.entitlements
api.AGPL = coderd.New(options.Options)
defer func() {
if err != nil {
Expand Down Expand Up @@ -493,7 +501,7 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
}
api.AGPL.WorkspaceProxiesFetchUpdater.Store(&fetchUpdater)

err = api.PrometheusRegistry.Register(&api.licenseMetricsCollector)
err = api.PrometheusRegistry.Register(api.licenseMetricsCollector)
if err != nil {
return nil, xerrors.Errorf("unable to register license metrics collector")
}
Expand Down Expand Up @@ -553,13 +561,11 @@ type API struct {
// ProxyHealth checks the reachability of all workspace proxies.
ProxyHealth *proxyhealth.ProxyHealth

entitlementsUpdateMu sync.Mutex
entitlementsMu sync.RWMutex
entitlements codersdk.Entitlements
entitlements *entitlements.Set

provisionerDaemonAuth *provisionerDaemonAuth

licenseMetricsCollector license.MetricsCollector
licenseMetricsCollector *license.MetricsCollector
tailnetService *tailnet.ClientService
}

Expand Down Expand Up @@ -588,11 +594,8 @@ func (api *API) writeEntitlementWarningsHeader(a rbac.Subject, header http.Heade
// has no roles. This is a normal user!
return
}
api.entitlementsMu.RLock()
defer api.entitlementsMu.RUnlock()
for _, warning := range api.entitlements.Warnings {
header.Add(codersdk.EntitlementsWarningHeader, warning)
}

api.entitlements.WriteEntitlementWarningHeaders(header)
}

func (api *API) Close() error {
Expand All @@ -614,9 +617,6 @@ func (api *API) Close() error {
}

func (api *API) updateEntitlements(ctx context.Context) error {
api.entitlementsUpdateMu.Lock()
defer api.entitlementsUpdateMu.Unlock()

replicas := api.replicaManager.AllPrimary()
agedReplicas := make([]database.Replica, 0, len(replicas))
for _, replica := range replicas {
Expand All @@ -632,7 +632,7 @@ func (api *API) updateEntitlements(ctx context.Context) error {
agedReplicas = append(agedReplicas, replica)
}

entitlements, err := license.Entitlements(
reloadedEntitlements, err := license.Entitlements(
ctx, api.Database,
len(agedReplicas), len(api.ExternalAuthConfigs), api.LicenseKeys, map[codersdk.FeatureName]bool{
codersdk.FeatureAuditLog: api.AuditLogging,
Expand All @@ -652,29 +652,24 @@ func (api *API) updateEntitlements(ctx context.Context) error {
return err
}

if entitlements.RequireTelemetry && !api.DeploymentValues.Telemetry.Enable.Value() {
if reloadedEntitlements.RequireTelemetry && !api.DeploymentValues.Telemetry.Enable.Value() {
// We can't fail because then the user couldn't remove the offending
// license w/o a restart.
//
// We don't simply append to entitlement.Errors since we don't want any
// enterprise features enabled.
api.entitlements.Errors = []string{
"License requires telemetry but telemetry is disabled",
}
api.entitlements.Update(func(entitlements *codersdk.Entitlements) {
entitlements.Errors = []string{
"License requires telemetry but telemetry is disabled",
}
})

api.Logger.Error(ctx, "license requires telemetry enabled")
return nil
}

featureChanged := func(featureName codersdk.FeatureName) (initial, changed, enabled bool) {
if api.entitlements.Features == nil {
return true, false, entitlements.Features[featureName].Enabled
}
oldFeature := api.entitlements.Features[featureName]
newFeature := entitlements.Features[featureName]
if oldFeature.Enabled != newFeature.Enabled {
return false, true, newFeature.Enabled
}
return false, false, newFeature.Enabled
return api.entitlements.FeatureChanged(featureName, reloadedEntitlements.Features[featureName])
}

shouldUpdate := func(initial, changed, enabled bool) bool {
Expand Down Expand Up @@ -831,20 +826,16 @@ func (api *API) updateEntitlements(ctx context.Context) error {
}

// External token encryption is soft-enforced
featureExternalTokenEncryption := entitlements.Features[codersdk.FeatureExternalTokenEncryption]
featureExternalTokenEncryption := reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption]
featureExternalTokenEncryption.Enabled = len(api.ExternalTokenEncryption) > 0
if featureExternalTokenEncryption.Enabled && featureExternalTokenEncryption.Entitlement != codersdk.EntitlementEntitled {
msg := fmt.Sprintf("%s is enabled (due to setting external token encryption keys) but your license is not entitled to this feature.", codersdk.FeatureExternalTokenEncryption.Humanize())
api.Logger.Warn(ctx, msg)
entitlements.Warnings = append(entitlements.Warnings, msg)
reloadedEntitlements.Warnings = append(reloadedEntitlements.Warnings, msg)
}
entitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption
reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption

api.entitlementsMu.Lock()
defer api.entitlementsMu.Unlock()
api.entitlements = entitlements
api.licenseMetricsCollector.Entitlements.Store(&entitlements)
api.AGPL.SiteHandler.Entitlements.Store(&entitlements)
api.entitlements.Replace(reloadedEntitlements)
return nil
}

Expand Down Expand Up @@ -1024,10 +1015,7 @@ func derpMapper(logger slog.Logger, proxyHealth *proxyhealth.ProxyHealth) func(*
// @Router /entitlements [get]
func (api *API) serveEntitlements(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
api.entitlementsMu.RLock()
entitlements := api.entitlements
api.entitlementsMu.RUnlock()
httpapi.Write(ctx, rw, http.StatusOK, entitlements)
httpapi.Write(ctx, rw, http.StatusOK, api.entitlements.AsJSON())
}

func (api *API) runEntitlementsLoop(ctx context.Context) {
Expand Down
Loading
Loading