Skip to content

Commit a96b6e6

Browse files
committed
chore: merge get groups sql queries into 1
1 parent 6f9b1a3 commit a96b6e6

File tree

10 files changed

+100
-191
lines changed

10 files changed

+100
-191
lines changed

coderd/database/dbauthz/dbauthz.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1411,19 +1411,16 @@ func (q *querier) GetGroupMembersCountByGroupID(ctx context.Context, groupID uui
14111411
return memberCount, nil
14121412
}
14131413

1414-
func (q *querier) GetGroups(ctx context.Context) ([]database.Group, error) {
1415-
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
1416-
return nil, err
1414+
func (q *querier) GetGroups(ctx context.Context, arg database.GetGroupsParams) ([]database.Group, error) {
1415+
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err == nil {
1416+
// Optimize this query for system users as it is used in telemetry.
1417+
// Calling authz on all groups in a deployment for telemetry jobs is
1418+
// excessive. Most user calls should have some filtering applied to reduce
1419+
// the size of the set.
1420+
return q.db.GetGroups(ctx, arg)
14171421
}
1418-
return q.db.GetGroups(ctx)
1419-
}
1420-
1421-
func (q *querier) GetGroupsByOrganizationAndUserID(ctx context.Context, arg database.GetGroupsByOrganizationAndUserIDParams) ([]database.Group, error) {
1422-
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetGroupsByOrganizationAndUserID)(ctx, arg)
1423-
}
14241422

1425-
func (q *querier) GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]database.Group, error) {
1426-
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetGroupsByOrganizationID)(ctx, organizationID)
1423+
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetGroups)(ctx, arg)
14271424
}
14281425

14291426
func (q *querier) GetHealthSettings(ctx context.Context) (string, error) {

coderd/database/dbmem/dbmem.go

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2599,51 +2599,38 @@ func (q *FakeQuerier) GetGroupMembersCountByGroupID(ctx context.Context, groupID
25992599
return int64(len(users)), nil
26002600
}
26012601

2602-
func (q *FakeQuerier) GetGroups(_ context.Context) ([]database.Group, error) {
2603-
q.mutex.RLock()
2604-
defer q.mutex.RUnlock()
2605-
2606-
out := make([]database.Group, len(q.groups))
2607-
copy(out, q.groups)
2608-
return out, nil
2609-
}
2610-
2611-
func (q *FakeQuerier) GetGroupsByOrganizationAndUserID(_ context.Context, arg database.GetGroupsByOrganizationAndUserIDParams) ([]database.Group, error) {
2602+
func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams) ([]database.Group, error) {
26122603
err := validateDatabaseType(arg)
26132604
if err != nil {
26142605
return nil, err
26152606
}
26162607

26172608
q.mutex.RLock()
26182609
defer q.mutex.RUnlock()
2619-
var groupIDs []uuid.UUID
2620-
for _, member := range q.groupMembers {
2621-
if member.UserID == arg.UserID {
2622-
groupIDs = append(groupIDs, member.GroupID)
2623-
}
2624-
}
2625-
groups := []database.Group{}
2626-
for _, group := range q.groups {
2627-
if slices.Contains(groupIDs, group.ID) && group.OrganizationID == arg.OrganizationID {
2628-
groups = append(groups, group)
2610+
2611+
groupIDs := make(map[uuid.UUID]struct{})
2612+
if arg.HasMemberID != uuid.Nil {
2613+
for _, member := range q.groupMembers {
2614+
if member.UserID == arg.HasMemberID {
2615+
groupIDs[member.GroupID] = struct{}{}
2616+
}
26292617
}
26302618
}
26312619

2632-
return groups, nil
2633-
}
2634-
2635-
func (q *FakeQuerier) GetGroupsByOrganizationID(_ context.Context, id uuid.UUID) ([]database.Group, error) {
2636-
q.mutex.RLock()
2637-
defer q.mutex.RUnlock()
2638-
2639-
groups := make([]database.Group, 0, len(q.groups))
2620+
filtered := make([]database.Group, 0)
26402621
for _, group := range q.groups {
2641-
if group.OrganizationID == id {
2642-
groups = append(groups, group)
2622+
if arg.OrganizationID != uuid.Nil && group.OrganizationID != arg.OrganizationID {
2623+
continue
2624+
}
2625+
_, ok := groupIDs[group.ID]
2626+
if arg.HasMemberID != uuid.Nil && !ok {
2627+
continue
26432628
}
2629+
2630+
filtered = append(filtered, group)
26442631
}
26452632

2646-
return groups, nil
2633+
return filtered, nil
26472634
}
26482635

26492636
func (q *FakeQuerier) GetHealthSettings(_ context.Context) (string, error) {

coderd/database/dbmetrics/dbmetrics.go

Lines changed: 2 additions & 16 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier.go

Lines changed: 1 addition & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 32 additions & 96 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/groups.sql

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
-- name: GetGroups :many
2-
SELECT * FROM groups;
3-
41
-- name: GetGroupByID :one
52
SELECT
63
*
@@ -23,30 +20,34 @@ AND
2320
LIMIT
2421
1;
2522

26-
-- name: GetGroupsByOrganizationID :many
27-
SELECT
28-
*
29-
FROM
30-
groups
31-
WHERE
32-
organization_id = $1;
33-
34-
-- name: GetGroupsByOrganizationAndUserID :many
23+
-- name: GetGroups :many
3524
SELECT
36-
groups.*
25+
*
3726
FROM
3827
groups
3928
WHERE
40-
groups.id IN (
41-
SELECT
42-
group_id
43-
FROM
44-
group_members_expanded gme
45-
WHERE
46-
gme.user_id = @user_id
47-
AND
48-
gme.organization_id = @organization_id
49-
);
29+
true
30+
AND CASE
31+
WHEN @organization_id:: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
32+
groups.organization_id = @organization_id
33+
ELSE true
34+
END
35+
AND CASE
36+
-- Filter to only include groups a user is a member of
37+
WHEN @has_member_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
38+
EXISTS (
39+
SELECT
40+
1
41+
FROM
42+
group_members
43+
WHERE
44+
group_members.group_id = groups.id
45+
AND
46+
group_members.user_id = @has_member_id
47+
)
48+
ELSE true
49+
END
50+
;
5051

5152
-- name: InsertGroup :one
5253
INSERT INTO groups (
@@ -68,15 +69,15 @@ INSERT INTO groups (
6869
id,
6970
name,
7071
organization_id,
71-
source
72+
source
7273
)
7374
SELECT
74-
gen_random_uuid(),
75-
group_name,
76-
@organization_id,
77-
@source
75+
gen_random_uuid(),
76+
group_name,
77+
@organization_id,
78+
@source
7879
FROM
79-
UNNEST(@group_names :: text[]) AS group_name
80+
UNNEST(@group_names :: text[]) AS group_name
8081
-- If the name conflicts, do nothing.
8182
ON CONFLICT DO NOTHING
8283
RETURNING *;

coderd/provisionerdserver/provisionerdserver.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -481,8 +481,8 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
481481
ownerSSHPublicKey = ownerSSHKey.PublicKey
482482
ownerSSHPrivateKey = ownerSSHKey.PrivateKey
483483
}
484-
ownerGroups, err := s.Database.GetGroupsByOrganizationAndUserID(ctx, database.GetGroupsByOrganizationAndUserIDParams{
485-
UserID: owner.ID,
484+
ownerGroups, err := s.Database.GetGroups(ctx, database.GetGroupsParams{
485+
HasMemberID: owner.ID,
486486
OrganizationID: s.OrganizationID,
487487
})
488488
if err != nil {

coderd/telemetry/telemetry.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
367367
return nil
368368
})
369369
eg.Go(func() error {
370-
groups, err := r.options.Database.GetGroups(ctx)
370+
groups, err := r.options.Database.GetGroups(ctx, database.GetGroupsParams{})
371371
if err != nil {
372372
return xerrors.Errorf("get groups: %w", err)
373373
}

0 commit comments

Comments
 (0)