Skip to content

Commit 65040d1

Browse files
separate org and user search
1 parent 87ad621 commit 65040d1

File tree

3 files changed

+250
-83
lines changed

3 files changed

+250
-83
lines changed

pkg/github/search.go

Lines changed: 123 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -168,99 +168,139 @@ type MinimalSearchUsersResult struct {
168168
Items []MinimalUser `json:"items"`
169169
}
170170

171-
// SearchUsers creates a tool to search for GitHub users.
172-
func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
173-
return mcp.NewTool("search_users",
174-
mcp.WithDescription(t("TOOL_SEARCH_USERS_DESCRIPTION", "Search for GitHub users")),
175-
mcp.WithToolAnnotation(mcp.ToolAnnotation{
176-
Title: t("TOOL_SEARCH_USERS_USER_TITLE", "Search users"),
177-
ReadOnlyHint: ToBoolPtr(true),
178-
}),
179-
mcp.WithString("q",
180-
mcp.Required(),
181-
mcp.Description("Search query using GitHub users search syntax"),
182-
),
183-
mcp.WithString("sort",
184-
mcp.Description("Sort field by category"),
185-
mcp.Enum("followers", "repositories", "joined"),
186-
),
187-
mcp.WithString("order",
188-
mcp.Description("Sort order"),
189-
mcp.Enum("asc", "desc"),
190-
),
191-
WithPagination(),
192-
),
193-
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
194-
query, err := RequiredParam[string](request, "q")
195-
if err != nil {
196-
return mcp.NewToolResultError(err.Error()), nil
197-
}
198-
sort, err := OptionalParam[string](request, "sort")
199-
if err != nil {
200-
return mcp.NewToolResultError(err.Error()), nil
201-
}
202-
order, err := OptionalParam[string](request, "order")
203-
if err != nil {
204-
return mcp.NewToolResultError(err.Error()), nil
205-
}
206-
pagination, err := OptionalPaginationParams(request)
207-
if err != nil {
208-
return mcp.NewToolResultError(err.Error()), nil
209-
}
171+
func userOrOrgHandler(accountType string, getClient GetClientFn) server.ToolHandlerFunc {
172+
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
173+
query, err := RequiredParam[string](request, "query")
174+
if err != nil {
175+
return mcp.NewToolResultError(err.Error()), nil
176+
}
177+
sort, err := OptionalParam[string](request, "sort")
178+
if err != nil {
179+
return mcp.NewToolResultError(err.Error()), nil
180+
}
181+
order, err := OptionalParam[string](request, "order")
182+
if err != nil {
183+
return mcp.NewToolResultError(err.Error()), nil
184+
}
185+
pagination, err := OptionalPaginationParams(request)
186+
if err != nil {
187+
return mcp.NewToolResultError(err.Error()), nil
188+
}
210189

211-
opts := &github.SearchOptions{
212-
Sort: sort,
213-
Order: order,
214-
ListOptions: github.ListOptions{
215-
PerPage: pagination.perPage,
216-
Page: pagination.page,
217-
},
218-
}
190+
opts := &github.SearchOptions{
191+
Sort: sort,
192+
Order: order,
193+
ListOptions: github.ListOptions{
194+
PerPage: pagination.perPage,
195+
Page: pagination.page,
196+
},
197+
}
219198

220-
client, err := getClient(ctx)
221-
if err != nil {
222-
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
223-
}
199+
client, err := getClient(ctx)
200+
if err != nil {
201+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
202+
}
224203

225-
result, resp, err := client.Search.Users(ctx, "type:user "+query, opts)
204+
searchQuery := "type:" + accountType + " " + query
205+
result, resp, err := client.Search.Users(ctx, searchQuery, opts)
206+
if err != nil {
207+
return ghErrors.NewGitHubAPIErrorResponse(ctx,
208+
fmt.Sprintf("failed to search %ss with query '%s'", accountType, query),
209+
resp,
210+
err,
211+
), nil
212+
}
213+
defer func() { _ = resp.Body.Close() }()
214+
215+
if resp.StatusCode != 200 {
216+
body, err := io.ReadAll(resp.Body)
226217
if err != nil {
227-
return ghErrors.NewGitHubAPIErrorResponse(ctx,
228-
fmt.Sprintf("failed to search users with query '%s'", query),
229-
resp,
230-
err,
231-
), nil
218+
return nil, fmt.Errorf("failed to read response body: %w", err)
232219
}
233-
defer func() { _ = resp.Body.Close() }()
220+
return mcp.NewToolResultError(fmt.Sprintf("failed to search %ss: %s", accountType, string(body))), nil
221+
}
234222

235-
if resp.StatusCode != 200 {
236-
body, err := io.ReadAll(resp.Body)
237-
if err != nil {
238-
return nil, fmt.Errorf("failed to read response body: %w", err)
239-
}
240-
return mcp.NewToolResultError(fmt.Sprintf("failed to search users: %s", string(body))), nil
241-
}
223+
minimalUsers := make([]MinimalUser, 0, len(result.Users))
242224

243-
minimalUsers := make([]MinimalUser, 0, len(result.Users))
244-
for _, user := range result.Users {
245-
mu := MinimalUser{
246-
Login: user.GetLogin(),
247-
ID: user.GetID(),
248-
ProfileURL: user.GetHTMLURL(),
249-
AvatarURL: user.GetAvatarURL(),
225+
for _, user := range result.Users {
226+
if user.Login != nil {
227+
mu := MinimalUser{Login: *user.Login}
228+
if user.ID != nil {
229+
mu.ID = *user.ID
230+
}
231+
if user.HTMLURL != nil {
232+
mu.ProfileURL = *user.HTMLURL
233+
}
234+
if user.AvatarURL != nil {
235+
mu.AvatarURL = *user.AvatarURL
250236
}
251-
252237
minimalUsers = append(minimalUsers, mu)
253238
}
239+
}
240+
minimalResp := &MinimalSearchUsersResult{
241+
TotalCount: result.GetTotal(),
242+
IncompleteResults: result.GetIncompleteResults(),
243+
Items: minimalUsers,
244+
}
245+
if result.Total != nil {
246+
minimalResp.TotalCount = *result.Total
247+
}
248+
if result.IncompleteResults != nil {
249+
minimalResp.IncompleteResults = *result.IncompleteResults
250+
}
254251

255-
minimalResp := MinimalSearchUsersResult{
256-
TotalCount: result.GetTotal(),
257-
IncompleteResults: result.GetIncompleteResults(),
258-
Items: minimalUsers,
259-
}
260-
r, err := json.Marshal(minimalResp)
261-
if err != nil {
262-
return nil, fmt.Errorf("failed to marshal response: %w", err)
263-
}
264-
return mcp.NewToolResultText(string(r)), nil
252+
r, err := json.Marshal(minimalResp)
253+
if err != nil {
254+
return nil, fmt.Errorf("failed to marshal response: %w", err)
265255
}
256+
return mcp.NewToolResultText(string(r)), nil
257+
}
258+
}
259+
260+
// SearchUsers creates a tool to search for GitHub users.
261+
func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
262+
return mcp.NewTool("search_users",
263+
mcp.WithDescription(t("TOOL_SEARCH_USERS_DESCRIPTION", "Search for GitHub users exclusively")),
264+
mcp.WithToolAnnotation(mcp.ToolAnnotation{
265+
Title: t("TOOL_SEARCH_USERS_USER_TITLE", "Search users"),
266+
ReadOnlyHint: ToBoolPtr(true),
267+
}),
268+
mcp.WithString("query",
269+
mcp.Required(),
270+
mcp.Description("Search query using GitHub users search syntax scoped to type:user"),
271+
),
272+
mcp.WithString("sort",
273+
mcp.Description("Sort field by category"),
274+
mcp.Enum("followers", "repositories", "joined"),
275+
),
276+
mcp.WithString("order",
277+
mcp.Description("Sort order"),
278+
mcp.Enum("asc", "desc"),
279+
),
280+
WithPagination(),
281+
), userOrOrgHandler("user", getClient)
282+
}
283+
284+
// SearchOrgs creates a tool to search for GitHub organizations.
285+
func SearchOrgs(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
286+
return mcp.NewTool("search_orgs",
287+
mcp.WithDescription(t("TOOL_SEARCH_ORGS_DESCRIPTION", "Search for GitHub organizations exclusively")),
288+
mcp.WithToolAnnotation(mcp.ToolAnnotation{
289+
Title: t("TOOL_SEARCH_ORGS_USER_TITLE", "Search organizations"),
290+
ReadOnlyHint: ToBoolPtr(true),
291+
}),
292+
mcp.WithString("query",
293+
mcp.Required(),
294+
mcp.Description("Search query using GitHub organizations search syntax scoped to type:org"),
295+
),
296+
mcp.WithString("sort",
297+
mcp.Description("Sort field by category"),
298+
mcp.Enum("followers", "repositories", "joined"),
299+
),
300+
mcp.WithString("order",
301+
mcp.Description("Sort order"),
302+
mcp.Enum("asc", "desc"),
303+
),
304+
WithPagination(),
305+
), userOrOrgHandler("org", getClient)
266306
}

pkg/github/search_test.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,3 +474,125 @@ func Test_SearchUsers(t *testing.T) {
474474
})
475475
}
476476
}
477+
478+
func Test_SearchOrgs(t *testing.T) {
479+
// Verify tool definition once
480+
mockClient := github.NewClient(nil)
481+
tool, _ := SearchOrgs(stubGetClientFn(mockClient), translations.NullTranslationHelper)
482+
483+
assert.Equal(t, "search_orgs", tool.Name)
484+
assert.NotEmpty(t, tool.Description)
485+
assert.Contains(t, tool.InputSchema.Properties, "q")
486+
assert.Contains(t, tool.InputSchema.Properties, "sort")
487+
assert.Contains(t, tool.InputSchema.Properties, "order")
488+
assert.Contains(t, tool.InputSchema.Properties, "perPage")
489+
assert.Contains(t, tool.InputSchema.Properties, "page")
490+
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"q"})
491+
492+
// Setup mock search results
493+
mockSearchResult := &github.UsersSearchResult{
494+
Total: github.Ptr(int(2)),
495+
IncompleteResults: github.Ptr(false),
496+
Users: []*github.User{
497+
{
498+
Login: github.Ptr("org-1"),
499+
ID: github.Ptr(int64(111)),
500+
HTMLURL: github.Ptr("https://github.com/org-1"),
501+
AvatarURL: github.Ptr("https://avatars.githubusercontent.com/u/111?v=4"),
502+
},
503+
{
504+
Login: github.Ptr("org-2"),
505+
ID: github.Ptr(int64(222)),
506+
HTMLURL: github.Ptr("https://github.com/org-2"),
507+
AvatarURL: github.Ptr("https://avatars.githubusercontent.com/u/222?v=4"),
508+
},
509+
},
510+
}
511+
512+
tests := []struct {
513+
name string
514+
mockedClient *http.Client
515+
requestArgs map[string]interface{}
516+
expectError bool
517+
expectedResult *github.UsersSearchResult
518+
expectedErrMsg string
519+
}{
520+
{
521+
name: "successful org search",
522+
mockedClient: mock.NewMockedHTTPClient(
523+
mock.WithRequestMatchHandler(
524+
mock.GetSearchUsers,
525+
expectQueryParams(t, map[string]string{
526+
"q": "type:org github",
527+
"page": "1",
528+
"per_page": "30",
529+
}).andThen(
530+
mockResponse(t, http.StatusOK, mockSearchResult),
531+
),
532+
),
533+
),
534+
requestArgs: map[string]interface{}{
535+
"q": "github",
536+
},
537+
expectError: false,
538+
expectedResult: mockSearchResult,
539+
},
540+
{
541+
name: "org search fails",
542+
mockedClient: mock.NewMockedHTTPClient(
543+
mock.WithRequestMatchHandler(
544+
mock.GetSearchUsers,
545+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
546+
w.WriteHeader(http.StatusBadRequest)
547+
_, _ = w.Write([]byte(`{"message": "Validation Failed"}`))
548+
}),
549+
),
550+
),
551+
requestArgs: map[string]interface{}{
552+
"q": "invalid:query",
553+
},
554+
expectError: true,
555+
expectedErrMsg: "failed to search orgs",
556+
},
557+
}
558+
559+
for _, tc := range tests {
560+
t.Run(tc.name, func(t *testing.T) {
561+
// Setup client with mock
562+
client := github.NewClient(tc.mockedClient)
563+
_, handler := SearchOrgs(stubGetClientFn(client), translations.NullTranslationHelper)
564+
565+
// Create call request
566+
request := createMCPRequest(tc.requestArgs)
567+
568+
// Call handler
569+
result, err := handler(context.Background(), request)
570+
571+
// Verify results
572+
if tc.expectError {
573+
require.Error(t, err)
574+
assert.Contains(t, err.Error(), tc.expectedErrMsg)
575+
return
576+
}
577+
578+
require.NoError(t, err)
579+
require.NotNil(t, result)
580+
581+
textContent := getTextResult(t, result)
582+
583+
// Unmarshal and verify the result
584+
var returnedResult MinimalSearchUsersResult
585+
err = json.Unmarshal([]byte(textContent.Text), &returnedResult)
586+
require.NoError(t, err)
587+
assert.Equal(t, *tc.expectedResult.Total, returnedResult.TotalCount)
588+
assert.Equal(t, *tc.expectedResult.IncompleteResults, returnedResult.IncompleteResults)
589+
assert.Len(t, returnedResult.Items, len(tc.expectedResult.Users))
590+
for i, org := range returnedResult.Items {
591+
assert.Equal(t, *tc.expectedResult.Users[i].Login, org.Login)
592+
assert.Equal(t, *tc.expectedResult.Users[i].ID, org.ID)
593+
assert.Equal(t, *tc.expectedResult.Users[i].HTMLURL, org.ProfileURL)
594+
assert.Equal(t, *tc.expectedResult.Users[i].AvatarURL, org.AvatarURL)
595+
}
596+
})
597+
}
598+
}

pkg/github/tools.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG
6464
AddReadTools(
6565
toolsets.NewServerTool(SearchUsers(getClient, t)),
6666
)
67+
orgs := toolsets.NewToolset("orgs", "GitHub Organization related tools").
68+
AddReadTools(
69+
toolsets.NewServerTool(SearchOrgs(getClient, t)),
70+
)
6771
pullRequests := toolsets.NewToolset("pull_requests", "GitHub Pull Request related tools").
6872
AddReadTools(
6973
toolsets.NewServerTool(GetPullRequest(getClient, t)),
@@ -143,6 +147,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG
143147
tsg.AddToolset(contextTools)
144148
tsg.AddToolset(repos)
145149
tsg.AddToolset(issues)
150+
tsg.AddToolset(orgs)
146151
tsg.AddToolset(users)
147152
tsg.AddToolset(pullRequests)
148153
tsg.AddToolset(actions)

0 commit comments

Comments
 (0)