Skip to content

Commit 2583fd0

Browse files
separate org and user search
1 parent 1495115 commit 2583fd0

File tree

3 files changed

+245
-90
lines changed

3 files changed

+245
-90
lines changed

pkg/github/search.go

Lines changed: 118 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -160,104 +160,132 @@ type MinimalSearchUsersResult struct {
160160
}
161161

162162
// SearchUsers creates a tool to search for GitHub users.
163-
func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
164-
return mcp.NewTool("search_users",
165-
mcp.WithDescription(t("TOOL_SEARCH_USERS_DESCRIPTION", "Search for GitHub users")),
166-
mcp.WithToolAnnotation(mcp.ToolAnnotation{
167-
Title: t("TOOL_SEARCH_USERS_USER_TITLE", "Search users"),
168-
ReadOnlyHint: toBoolPtr(true),
169-
}),
170-
mcp.WithString("q",
171-
mcp.Required(),
172-
mcp.Description("Search query using GitHub users search syntax"),
173-
),
174-
mcp.WithString("sort",
175-
mcp.Description("Sort field by category"),
176-
mcp.Enum("followers", "repositories", "joined"),
177-
),
178-
mcp.WithString("order",
179-
mcp.Description("Sort order"),
180-
mcp.Enum("asc", "desc"),
181-
),
182-
WithPagination(),
183-
),
184-
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
185-
query, err := requiredParam[string](request, "q")
186-
if err != nil {
187-
return mcp.NewToolResultError(err.Error()), nil
188-
}
189-
sort, err := OptionalParam[string](request, "sort")
190-
if err != nil {
191-
return mcp.NewToolResultError(err.Error()), nil
192-
}
193-
order, err := OptionalParam[string](request, "order")
194-
if err != nil {
195-
return mcp.NewToolResultError(err.Error()), nil
196-
}
197-
pagination, err := OptionalPaginationParams(request)
198-
if err != nil {
199-
return mcp.NewToolResultError(err.Error()), nil
200-
}
163+
func userOrOrgHandler(accountType string, getClient GetClientFn) server.ToolHandlerFunc {
164+
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
165+
query, err := requiredParam[string](request, "q")
166+
if err != nil {
167+
return mcp.NewToolResultError(err.Error()), nil
168+
}
169+
sort, err := OptionalParam[string](request, "sort")
170+
if err != nil {
171+
return mcp.NewToolResultError(err.Error()), nil
172+
}
173+
order, err := OptionalParam[string](request, "order")
174+
if err != nil {
175+
return mcp.NewToolResultError(err.Error()), nil
176+
}
177+
pagination, err := OptionalPaginationParams(request)
178+
if err != nil {
179+
return mcp.NewToolResultError(err.Error()), nil
180+
}
201181

202-
opts := &github.SearchOptions{
203-
Sort: sort,
204-
Order: order,
205-
ListOptions: github.ListOptions{
206-
PerPage: pagination.perPage,
207-
Page: pagination.page,
208-
},
209-
}
182+
opts := &github.SearchOptions{
183+
Sort: sort,
184+
Order: order,
185+
ListOptions: github.ListOptions{
186+
PerPage: pagination.perPage,
187+
Page: pagination.page,
188+
},
189+
}
210190

211-
client, err := getClient(ctx)
212-
if err != nil {
213-
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
214-
}
191+
client, err := getClient(ctx)
192+
if err != nil {
193+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
194+
}
215195

216-
result, resp, err := client.Search.Users(ctx, "type:user "+query, opts)
196+
searchQuery := "type:" + accountType + " " + query
197+
result, resp, err := client.Search.Users(ctx, searchQuery, opts)
198+
if err != nil {
199+
return nil, fmt.Errorf("failed to search %ss: %w", accountType, err)
200+
}
201+
defer func() { _ = resp.Body.Close() }()
202+
203+
if resp.StatusCode != 200 {
204+
body, err := io.ReadAll(resp.Body)
217205
if err != nil {
218-
return nil, fmt.Errorf("failed to search users: %w", err)
206+
return nil, fmt.Errorf("failed to read response body: %w", err)
219207
}
220-
defer func() { _ = resp.Body.Close() }()
208+
return mcp.NewToolResultError(fmt.Sprintf("failed to search %ss: %s", accountType, string(body))), nil
209+
}
221210

222-
if resp.StatusCode != 200 {
223-
body, err := io.ReadAll(resp.Body)
224-
if err != nil {
225-
return nil, fmt.Errorf("failed to read response body: %w", err)
211+
minimalUsers := make([]MinimalUser, 0, len(result.Users))
212+
for _, user := range result.Users {
213+
if user.Login != nil {
214+
mu := MinimalUser{Login: *user.Login}
215+
if user.ID != nil {
216+
mu.ID = *user.ID
226217
}
227-
return mcp.NewToolResultError(fmt.Sprintf("failed to search users: %s", string(body))), nil
228-
}
229-
230-
minimalUsers := make([]MinimalUser, 0, len(result.Users))
231-
for _, user := range result.Users {
232-
if user.Login != nil {
233-
mu := MinimalUser{Login: *user.Login}
234-
if user.ID != nil {
235-
mu.ID = *user.ID
236-
}
237-
if user.HTMLURL != nil {
238-
mu.ProfileURL = *user.HTMLURL
239-
}
240-
if user.AvatarURL != nil {
241-
mu.AvatarURL = *user.AvatarURL
242-
}
243-
minimalUsers = append(minimalUsers, mu)
218+
if user.HTMLURL != nil {
219+
mu.ProfileURL = *user.HTMLURL
244220
}
221+
if user.AvatarURL != nil {
222+
mu.AvatarURL = *user.AvatarURL
223+
}
224+
minimalUsers = append(minimalUsers, mu)
245225
}
246-
minimalResp := MinimalSearchUsersResult{
247-
TotalCount: result.GetTotal(),
248-
IncompleteResults: result.GetIncompleteResults(),
249-
Items: minimalUsers,
250-
}
251-
if result.Total != nil {
252-
minimalResp.TotalCount = *result.Total
253-
}
254-
if result.IncompleteResults != nil {
255-
minimalResp.IncompleteResults = *result.IncompleteResults
256-
}
257-
r, err := json.Marshal(minimalResp)
258-
if err != nil {
259-
return nil, fmt.Errorf("failed to marshal response: %w", err)
260-
}
261-
return mcp.NewToolResultText(string(r)), nil
262226
}
227+
minimalResp := &MinimalSearchUsersResult{
228+
TotalCount: result.GetTotal(),
229+
IncompleteResults: result.GetIncompleteResults(),
230+
Items: minimalUsers,
231+
}
232+
if result.Total != nil {
233+
minimalResp.TotalCount = *result.Total
234+
}
235+
if result.IncompleteResults != nil {
236+
minimalResp.IncompleteResults = *result.IncompleteResults
237+
}
238+
r, err := json.Marshal(minimalResp)
239+
if err != nil {
240+
return nil, fmt.Errorf("failed to marshal response: %w", err)
241+
}
242+
return mcp.NewToolResultText(string(r)), nil
243+
}
244+
}
245+
246+
func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
247+
return mcp.NewTool("search_users",
248+
mcp.WithDescription(t("TOOL_SEARCH_USERS_DESCRIPTION", "Search for GitHub users exlusively")),
249+
mcp.WithToolAnnotation(mcp.ToolAnnotation{
250+
Title: t("TOOL_SEARCH_USERS_USER_TITLE", "Search users"),
251+
ReadOnlyHint: toBoolPtr(true),
252+
}),
253+
mcp.WithString("q",
254+
mcp.Required(),
255+
mcp.Description("Search query using GitHub users search syntax scoped to type:user"),
256+
),
257+
mcp.WithString("sort",
258+
mcp.Description("Sort field by category"),
259+
mcp.Enum("followers", "repositories", "joined"),
260+
),
261+
mcp.WithString("order",
262+
mcp.Description("Sort order"),
263+
mcp.Enum("asc", "desc"),
264+
),
265+
WithPagination(),
266+
), userOrOrgHandler("user", getClient)
267+
}
268+
269+
// SearchOrgs creates a tool to search for GitHub organizations.
270+
func SearchOrgs(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
271+
return mcp.NewTool("search_orgs",
272+
mcp.WithDescription(t("TOOL_SEARCH_ORGS_DESCRIPTION", "Search for GitHub organizations exclusively")),
273+
mcp.WithToolAnnotation(mcp.ToolAnnotation{
274+
Title: t("TOOL_SEARCH_ORGS_USER_TITLE", "Search organizations"),
275+
ReadOnlyHint: toBoolPtr(true),
276+
}),
277+
mcp.WithString("q",
278+
mcp.Required(),
279+
mcp.Description("Search query using GitHub organizations search syntax scoped to type:org"),
280+
),
281+
mcp.WithString("sort",
282+
mcp.Description("Sort field by category"),
283+
mcp.Enum("followers", "repositories", "joined"),
284+
),
285+
mcp.WithString("order",
286+
mcp.Description("Sort order"),
287+
mcp.Enum("asc", "desc"),
288+
),
289+
WithPagination(),
290+
), userOrOrgHandler("org", getClient)
263291
}

pkg/github/search_test.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,3 +461,125 @@ func Test_SearchUsers(t *testing.T) {
461461
})
462462
}
463463
}
464+
465+
func Test_SearchOrgs(t *testing.T) {
466+
// Verify tool definition once
467+
mockClient := github.NewClient(nil)
468+
tool, _ := SearchOrgs(stubGetClientFn(mockClient), translations.NullTranslationHelper)
469+
470+
assert.Equal(t, "search_orgs", tool.Name)
471+
assert.NotEmpty(t, tool.Description)
472+
assert.Contains(t, tool.InputSchema.Properties, "q")
473+
assert.Contains(t, tool.InputSchema.Properties, "sort")
474+
assert.Contains(t, tool.InputSchema.Properties, "order")
475+
assert.Contains(t, tool.InputSchema.Properties, "perPage")
476+
assert.Contains(t, tool.InputSchema.Properties, "page")
477+
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"q"})
478+
479+
// Setup mock search results
480+
mockSearchResult := &github.UsersSearchResult{
481+
Total: github.Ptr(int(2)),
482+
IncompleteResults: github.Ptr(false),
483+
Users: []*github.User{
484+
{
485+
Login: github.Ptr("org-1"),
486+
ID: github.Ptr(int64(111)),
487+
HTMLURL: github.Ptr("https://github.com/org-1"),
488+
AvatarURL: github.Ptr("https://avatars.githubusercontent.com/u/111?v=4"),
489+
},
490+
{
491+
Login: github.Ptr("org-2"),
492+
ID: github.Ptr(int64(222)),
493+
HTMLURL: github.Ptr("https://github.com/org-2"),
494+
AvatarURL: github.Ptr("https://avatars.githubusercontent.com/u/222?v=4"),
495+
},
496+
},
497+
}
498+
499+
tests := []struct {
500+
name string
501+
mockedClient *http.Client
502+
requestArgs map[string]interface{}
503+
expectError bool
504+
expectedResult *github.UsersSearchResult
505+
expectedErrMsg string
506+
}{
507+
{
508+
name: "successful org search",
509+
mockedClient: mock.NewMockedHTTPClient(
510+
mock.WithRequestMatchHandler(
511+
mock.GetSearchUsers,
512+
expectQueryParams(t, map[string]string{
513+
"q": "type:org github",
514+
"page": "1",
515+
"per_page": "30",
516+
}).andThen(
517+
mockResponse(t, http.StatusOK, mockSearchResult),
518+
),
519+
),
520+
),
521+
requestArgs: map[string]interface{}{
522+
"q": "github",
523+
},
524+
expectError: false,
525+
expectedResult: mockSearchResult,
526+
},
527+
{
528+
name: "org search fails",
529+
mockedClient: mock.NewMockedHTTPClient(
530+
mock.WithRequestMatchHandler(
531+
mock.GetSearchUsers,
532+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
533+
w.WriteHeader(http.StatusBadRequest)
534+
_, _ = w.Write([]byte(`{"message": "Validation Failed"}`))
535+
}),
536+
),
537+
),
538+
requestArgs: map[string]interface{}{
539+
"q": "invalid:query",
540+
},
541+
expectError: true,
542+
expectedErrMsg: "failed to search orgs",
543+
},
544+
}
545+
546+
for _, tc := range tests {
547+
t.Run(tc.name, func(t *testing.T) {
548+
// Setup client with mock
549+
client := github.NewClient(tc.mockedClient)
550+
_, handler := SearchOrgs(stubGetClientFn(client), translations.NullTranslationHelper)
551+
552+
// Create call request
553+
request := createMCPRequest(tc.requestArgs)
554+
555+
// Call handler
556+
result, err := handler(context.Background(), request)
557+
558+
// Verify results
559+
if tc.expectError {
560+
require.Error(t, err)
561+
assert.Contains(t, err.Error(), tc.expectedErrMsg)
562+
return
563+
}
564+
565+
require.NoError(t, err)
566+
require.NotNil(t, result)
567+
568+
textContent := getTextResult(t, result)
569+
570+
// Unmarshal and verify the result
571+
var returnedResult MinimalSearchUsersResult
572+
err = json.Unmarshal([]byte(textContent.Text), &returnedResult)
573+
require.NoError(t, err)
574+
assert.Equal(t, *tc.expectedResult.Total, returnedResult.TotalCount)
575+
assert.Equal(t, *tc.expectedResult.IncompleteResults, returnedResult.IncompleteResults)
576+
assert.Len(t, returnedResult.Items, len(tc.expectedResult.Users))
577+
for i, org := range returnedResult.Items {
578+
assert.Equal(t, *tc.expectedResult.Users[i].Login, org.Login)
579+
assert.Equal(t, *tc.expectedResult.Users[i].ID, org.ID)
580+
assert.Equal(t, *tc.expectedResult.Users[i].HTMLURL, org.ProfileURL)
581+
assert.Equal(t, *tc.expectedResult.Users[i].AvatarURL, org.AvatarURL)
582+
}
583+
})
584+
}
585+
}

pkg/github/tools.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
5757
AddReadTools(
5858
toolsets.NewServerTool(SearchUsers(getClient, t)),
5959
)
60+
orgs := toolsets.NewToolset("orgs", "GitHub Organization related tools").
61+
AddReadTools(
62+
toolsets.NewServerTool(SearchOrgs(getClient, t)),
63+
)
6064
pullRequests := toolsets.NewToolset("pull_requests", "GitHub Pull Request related tools").
6165
AddReadTools(
6266
toolsets.NewServerTool(GetPullRequest(getClient, t)),
@@ -111,6 +115,7 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
111115
tsg.AddToolset(repos)
112116
tsg.AddToolset(issues)
113117
tsg.AddToolset(users)
118+
tsg.AddToolset(orgs)
114119
tsg.AddToolset(pullRequests)
115120
tsg.AddToolset(codeSecurity)
116121
tsg.AddToolset(secretProtection)

0 commit comments

Comments
 (0)