diff --git a/provider/activitypub/mastodon/client.go b/provider/activitypub/mastodon/client.go index 529b477c..db18cda5 100644 --- a/provider/activitypub/mastodon/client.go +++ b/provider/activitypub/mastodon/client.go @@ -11,7 +11,6 @@ import ( "errors" "fmt" "io" - "net" "net/http" "net/url" "strconv" @@ -159,30 +158,39 @@ func NewClient(ctx context.Context, endpoint string, relayList []string, errorCh return c, nil } -// parseEndpoint splits a URL endpoint into its hostname and port components. -func parseEndpoint(endpoint string) (hostname string, port int64, err error) { - parsedHostname := strings.TrimPrefix(endpoint, "https://") - - zap.L().Info("Parsed hostname", zap.String("parsedHostname", parsedHostname)) +// parseEndpoint splits a URL endpoint into its domain (with protocol) and port components. +func parseEndpoint(endpoint string) (domain string, port int64, err error) { + // check if the endpoint starts with 'https://' + if !strings.HasPrefix(endpoint, "https://") { + return "", 0, fmt.Errorf("invalid endpoint format: must start with 'https://' in '%s'", endpoint) + } - // Check if the host contains a port - if !strings.Contains(parsedHostname, ":") { - return "", 0, fmt.Errorf("invalid host:port format: missing port in '%s'", parsedHostname) + // check if the endpoint contains a colon + if !strings.Contains(endpoint, ":") { + return "", 0, fmt.Errorf("invalid host:port format: missing port in '%s'", endpoint) } - // Split the hostname and port - parsedHostname, hostPortInStr, err := net.SplitHostPort(parsedHostname) - if err != nil { - return "", 0, fmt.Errorf("invalid host:port format: %w", err) + // split the endpoint into domain and port + parts := strings.Split(endpoint, ":") + // https: + domain + port + if len(parts) != 3 { + return "", 0, fmt.Errorf("invalid endpoint format in '%s'", endpoint) } - // Convert the port to int64 - parsedPort, err := strconv.ParseInt(hostPortInStr, 10, 64) + // reconstruct the protocol header and domain + domain = parts[0] + ":" + parts[1] + + // parse the port + parsedPort, err := strconv.ParseInt(parts[2], 10, 64) if err != nil { return "", 0, fmt.Errorf("invalid port value: %w", err) } - return parsedHostname, parsedPort, nil + zap.L().Info("Parsed endpoint", + zap.String("domain", domain), + zap.Int64("port", parsedPort)) + + return domain, parsedPort, nil } // initializeServer initializes an echo http server