Skip to content

Commit d03fe73

Browse files
committed
chore: populate connectionlog count using a separate query
1 parent 8cec6d1 commit d03fe73

File tree

14 files changed

+663
-5
lines changed

14 files changed

+663
-5
lines changed

coderd/database/dbauthz/dbauthz.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,21 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error {
13231323
return q.db.CleanTailnetTunnels(ctx)
13241324
}
13251325

1326+
func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) {
1327+
// Just like the actual query, shortcut if the user is an owner.
1328+
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
1329+
if err == nil {
1330+
return q.db.CountConnectionLogs(ctx, arg)
1331+
}
1332+
1333+
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceConnectionLog.Type)
1334+
if err != nil {
1335+
return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
1336+
}
1337+
1338+
return q.db.CountAuthorizedConnectionLogs(ctx, arg, prep)
1339+
}
1340+
13261341
func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
13271342
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
13281343
return nil, err
@@ -5301,3 +5316,7 @@ func (q *querier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg database
53015316
func (q *querier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, _ rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) {
53025317
return q.GetConnectionLogsOffset(ctx, arg)
53035318
}
5319+
5320+
func (q *querier) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, _ rbac.PreparedAuthorized) (int64, error) {
5321+
return q.CountConnectionLogs(ctx, arg)
5322+
}

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,42 @@ func (s *MethodTestSuite) TestConnectionLogs() {
391391
LimitOpt: 10,
392392
}, emptyPreparedAuthorized{}).Asserts(rbac.ResourceConnectionLog, policy.ActionRead)
393393
}))
394+
s.Run("CountConnectionLogs", s.Subtest(func(db database.Store, check *expects) {
395+
ws := createWorkspace(s.T(), db)
396+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
397+
Type: database.ConnectionTypeSsh,
398+
WorkspaceID: ws.ID,
399+
OrganizationID: ws.OrganizationID,
400+
WorkspaceOwnerID: ws.OwnerID,
401+
})
402+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
403+
Type: database.ConnectionTypeSsh,
404+
WorkspaceID: ws.ID,
405+
OrganizationID: ws.OrganizationID,
406+
WorkspaceOwnerID: ws.OwnerID,
407+
})
408+
check.Args(database.CountConnectionLogsParams{}).Asserts(
409+
rbac.ResourceConnectionLog, policy.ActionRead,
410+
).WithNotAuthorized("nil")
411+
}))
412+
s.Run("CountAuthorizedConnectionLogs", s.Subtest(func(db database.Store, check *expects) {
413+
ws := createWorkspace(s.T(), db)
414+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
415+
Type: database.ConnectionTypeSsh,
416+
WorkspaceID: ws.ID,
417+
OrganizationID: ws.OrganizationID,
418+
WorkspaceOwnerID: ws.OwnerID,
419+
})
420+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
421+
Type: database.ConnectionTypeSsh,
422+
WorkspaceID: ws.ID,
423+
OrganizationID: ws.OrganizationID,
424+
WorkspaceOwnerID: ws.OwnerID,
425+
})
426+
check.Args(database.CountConnectionLogsParams{}, emptyPreparedAuthorized{}).Asserts(
427+
rbac.ResourceConnectionLog, policy.ActionRead,
428+
)
429+
}))
394430
}
395431

396432
func (s *MethodTestSuite) TestFile() {

coderd/database/dbauthz/setup_test.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
271271

272272
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
273273
// any case where the error is nil and the response is an empty slice.
274-
if err != nil || !hasEmptySliceResponse(resp) {
274+
if err != nil || !hasEmptyResponse(resp) {
275275
// Expect the default error
276276
if testCase.notAuthorizedExpect == "" {
277277
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
@@ -297,7 +297,7 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
297297

298298
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
299299
// any case where the error is nil and the response is an empty slice.
300-
if err != nil || !hasEmptySliceResponse(resp) {
300+
if err != nil || !hasEmptyResponse(resp) {
301301
if testCase.cancelledCtxExpect == "" {
302302
s.Errorf(err, "method should an error with cancellation")
303303
s.ErrorIsf(err, context.Canceled, "error should match context.Canceled")
@@ -308,13 +308,20 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
308308
})
309309
}
310310

311-
func hasEmptySliceResponse(values []reflect.Value) bool {
311+
func hasEmptyResponse(values []reflect.Value) bool {
312312
for _, r := range values {
313313
if r.Kind() == reflect.Slice || r.Kind() == reflect.Array {
314314
if r.Len() == 0 {
315315
return true
316316
}
317317
}
318+
319+
// Special case for int64, as it's the return type for count queries.
320+
if r.Kind() == reflect.Int64 {
321+
if r.Int() == 0 {
322+
return true
323+
}
324+
}
318325
}
319326
return false
320327
}

coderd/database/dbmem/dbmem.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1780,6 +1780,10 @@ func (*FakeQuerier) CleanTailnetTunnels(context.Context) error {
17801780
return ErrUnimplemented
17811781
}
17821782

1783+
func (q *FakeQuerier) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) {
1784+
return q.CountAuthorizedConnectionLogs(ctx, arg, nil)
1785+
}
1786+
17831787
func (q *FakeQuerier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
17841788
return nil, ErrUnimplemented
17851789
}
@@ -14156,3 +14160,93 @@ func (q *FakeQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg
1415614160

1415714161
return logs, nil
1415814162
}
14163+
14164+
func (q *FakeQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
14165+
if err := validateDatabaseType(arg); err != nil {
14166+
return 0, err
14167+
}
14168+
14169+
// Call this to match the same function calls as the SQL implementation.
14170+
// It functionally does nothing for filtering.
14171+
if prepared != nil {
14172+
_, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
14173+
VariableConverter: regosql.ConnectionLogConverter(),
14174+
})
14175+
if err != nil {
14176+
return 0, err
14177+
}
14178+
}
14179+
14180+
q.mutex.RLock()
14181+
defer q.mutex.RUnlock()
14182+
14183+
var count int64
14184+
14185+
for _, clog := range q.connectionLogs {
14186+
if arg.OrganizationID != uuid.Nil && clog.OrganizationID != arg.OrganizationID {
14187+
continue
14188+
}
14189+
if arg.WorkspaceOwner != "" {
14190+
workspaceOwner, err := q.getUserByIDNoLock(clog.WorkspaceOwnerID)
14191+
if err == nil && !strings.EqualFold(arg.WorkspaceOwner, workspaceOwner.Username) {
14192+
continue
14193+
}
14194+
}
14195+
if arg.Type != "" && string(clog.Type) != arg.Type {
14196+
continue
14197+
}
14198+
if arg.UserID != uuid.Nil && (!clog.UserID.Valid || clog.UserID.UUID != arg.UserID) {
14199+
continue
14200+
}
14201+
if arg.Username != "" {
14202+
if !clog.UserID.Valid {
14203+
continue
14204+
}
14205+
user, err := q.getUserByIDNoLock(clog.UserID.UUID)
14206+
if err != nil || user.Username != arg.Username {
14207+
continue
14208+
}
14209+
}
14210+
if arg.Email != "" {
14211+
if !clog.UserID.Valid {
14212+
continue
14213+
}
14214+
user, err := q.getUserByIDNoLock(clog.UserID.UUID)
14215+
if err != nil || user.Email != arg.Email {
14216+
continue
14217+
}
14218+
}
14219+
if !arg.StartedAfter.IsZero() && clog.Time.Before(arg.StartedAfter) {
14220+
continue
14221+
}
14222+
if !arg.StartedBefore.IsZero() && clog.Time.After(arg.StartedBefore) {
14223+
continue
14224+
}
14225+
if !arg.ClosedAfter.IsZero() && (!clog.CloseTime.Valid || clog.CloseTime.Time.Before(arg.ClosedAfter)) {
14226+
continue
14227+
}
14228+
if !arg.ClosedBefore.IsZero() && (!clog.CloseTime.Valid || clog.CloseTime.Time.After(arg.ClosedBefore)) {
14229+
continue
14230+
}
14231+
if arg.WorkspaceID != uuid.Nil && clog.WorkspaceID != arg.WorkspaceID {
14232+
continue
14233+
}
14234+
if arg.ConnectionID != uuid.Nil && (!clog.ConnectionID.Valid || clog.ConnectionID.UUID != arg.ConnectionID) {
14235+
continue
14236+
}
14237+
if arg.Status != "" {
14238+
isConnected := !clog.CloseTime.Valid
14239+
if (arg.Status == "connected" && !isConnected) || (arg.Status == "disconnected" && isConnected) {
14240+
continue
14241+
}
14242+
}
14243+
14244+
if prepared != nil && prepared.Authorize(ctx, clog.RBACObject()) != nil {
14245+
continue
14246+
}
14247+
14248+
count++
14249+
}
14250+
14251+
return count, nil
14252+
}

coderd/database/dbmetrics/querymetrics.go

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 30 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/modelqueries.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,7 @@ func (q *sqlQuerier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg GetAu
566566

567567
type connectionLogQuerier interface {
568568
GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error)
569+
CountAuthorizedConnectionLogs(ctx context.Context, arg CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error)
569570
}
570571

571572
func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error) {
@@ -653,6 +654,53 @@ func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg
653654
return items, nil
654655
}
655656

657+
func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
658+
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
659+
VariableConverter: regosql.ConnectionLogConverter(),
660+
})
661+
if err != nil {
662+
return 0, xerrors.Errorf("compile authorized filter: %w", err)
663+
}
664+
filtered, err := insertAuthorizedFilter(countConnectionLogs, fmt.Sprintf(" AND %s", authorizedFilter))
665+
if err != nil {
666+
return 0, xerrors.Errorf("insert authorized filter: %w", err)
667+
}
668+
669+
query := fmt.Sprintf("-- name: CountAuthorizedConnectionLogs :one\n%s", filtered)
670+
rows, err := q.db.QueryContext(ctx, query,
671+
arg.OrganizationID,
672+
arg.WorkspaceOwner,
673+
arg.Type,
674+
arg.UserID,
675+
arg.Username,
676+
arg.Email,
677+
arg.StartedAfter,
678+
arg.StartedBefore,
679+
arg.ClosedAfter,
680+
arg.ClosedBefore,
681+
arg.WorkspaceID,
682+
arg.ConnectionID,
683+
arg.Status,
684+
)
685+
if err != nil {
686+
return 0, err
687+
}
688+
defer rows.Close()
689+
var count int64
690+
for rows.Next() {
691+
if err := rows.Scan(&count); err != nil {
692+
return 0, err
693+
}
694+
}
695+
if err := rows.Close(); err != nil {
696+
return 0, err
697+
}
698+
if err := rows.Err(); err != nil {
699+
return 0, err
700+
}
701+
return count, nil
702+
}
703+
656704
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
657705
if !strings.Contains(query, authorizedQueryPlaceholder) {
658706
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")

coderd/database/modelqueries_internal_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package database
22

33
import (
4+
"regexp"
5+
"strings"
46
"testing"
57
"time"
68

@@ -54,3 +56,35 @@ func TestWorkspaceTableConvert(t *testing.T) {
5456
"'workspace.WorkspaceTable()' is not missing at least 1 field when converting to 'WorkspaceTable'. "+
5557
"To resolve this, go to the 'func (w Workspace) WorkspaceTable()' and ensure all fields are converted.")
5658
}
59+
60+
func TestConnectionLogsQueryConsistency(t *testing.T) {
61+
t.Parallel()
62+
63+
getWhereClause := extractWhereClause(getConnectionLogsOffset)
64+
require.NotEmpty(t, getWhereClause, "getConnectionLogsOffset query should have a WHERE clause")
65+
66+
countWhereClause := extractWhereClause(countConnectionLogs)
67+
require.NotEmpty(t, countWhereClause, "countConnectionLogs query should have a WHERE clause")
68+
69+
require.Equal(t, getWhereClause, countWhereClause, "getConnectionLogsOffset and countConnectionLogs queries should have the same WHERE clause")
70+
}
71+
72+
// extractWhereClause extracts the WHERE clause from a SQL query string
73+
func extractWhereClause(query string) string {
74+
// Find WHERE and get everything after it
75+
wherePattern := regexp.MustCompile(`(?is)WHERE\s+(.*)`)
76+
whereMatches := wherePattern.FindStringSubmatch(query)
77+
if len(whereMatches) < 2 {
78+
return ""
79+
}
80+
81+
whereClause := whereMatches[1]
82+
83+
// Remove ORDER BY, LIMIT, OFFSET clauses from the end
84+
whereClause = regexp.MustCompile(`(?is)\s+(ORDER BY|LIMIT|OFFSET).*$`).ReplaceAllString(whereClause, "")
85+
86+
// Remove SQL comments
87+
whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "")
88+
89+
return strings.TrimSpace(whereClause)
90+
}

coderd/database/querier.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)