Explorar o código

jtp: remove internal goroutine

Benton Edmondson %!s(int64=2) %!d(string=hai) anos
pai
achega
764c6058c9
Modificáronse 2 ficheiros con 93 adicións e 112 borrados
  1. 70 103
      jtp/jtp.go
  2. 23 9
      jtp/jtp_test.go

+ 70 - 103
jtp/jtp.go

@@ -11,7 +11,6 @@ import (
 	"fmt"
 	"strings"
 	"encoding/json"
-	. "mimicry/preamble"
 )
 
 // TODO: parseMediaType should probably return an error if the mediaType is invalid
@@ -33,15 +32,6 @@ var statusLineRegexp = regexp.MustCompile(`^HTTP/1\.[0-9] ([0-9]{3}).*\n$`)
 var contentTypeRegexp = regexp.MustCompile(`^(?i:content-type):[ \t\r]*(.*?)[ \t\r]*\n$`)
 var locationRegexp = regexp.MustCompile(`^(?i:location):[ \t\r]*(.*?)[ \t\r]*\n$`)
 
-var acceptHeader = `application/activity+json,` +
-	`application/ld+json; profile="https://www.w3.org/ns/activitystreams"`
-	
-var toleratedTypes = []string{
-	"application/activity+json",
-	"application/ld+json",
-	"application/json",
-}
-
 /*
 	I send an HTTP/1.0 request to ensure the server doesn't respond
 	with chunked transfer encoding.
@@ -54,116 +44,93 @@ var toleratedTypes = []string{
 	maxRedirects
 		the maximum number of redirects to take
 */
-func Get(link *url.URL, maxRedirects uint) <-chan *Result[map[string]any] {
+func Get(link *url.URL, accept string, tolerated []string, maxRedirects uint) (map[string]any, error) {
+	if link.Scheme != "https" {
+		return nil, errors.New(link.Scheme + " is not supported in requests, only https")
+	}
 
-	channel := make(chan *Result[map[string]any], 1)
+	port := link.Port()
+	if port == "" {
+		port = "443"
+	}
 
-	go func() {
-		if link.Scheme != "https" {
-			channel <- Err[map[string]any](errors.New(link.Scheme + " is not supported in requests, only https"))
-			return
-		}
+	hostport := net.JoinHostPort(link.Hostname(), port)
 
-		port := link.Port()
-		if port == "" {
-			port = "443"
-		}
-
-		hostport := net.JoinHostPort(link.Hostname(), port)
+	connection, err := dialer.Dial("tcp", hostport)
+	if err != nil {
+		return nil, err
+	}
 
-		connection, err := dialer.Dial("tcp", hostport)
-		if err != nil {
-			channel <- Err[map[string]any](err)
-			return
-		}
+	_, err = connection.Write([]byte(
+		"GET " + link.RequestURI() + " HTTP/1.0\r\n" +
+		"Host: " + link.Host + "\r\n" +
+		"Accept: " + accept + "\r\n" +
+		"Accept-Encoding: identity\r\n" +
+		"\r\n",
+	))
+	if err != nil {
+		return nil, errors.Join(err, connection.Close())
+	}
 
-		_, err = connection.Write([]byte(
-			"GET " + link.RequestURI() + " HTTP/1.0\r\n" +
-			"Host: " + link.Host + "\r\n" +
-			"Accept: " + acceptHeader + "\r\n" +
-			"Accept-Encoding: identity\r\n" +
-			"\r\n",
-		))
-		if err != nil {
-			channel <- Err[map[string]any](err, connection.Close())
-			return
-		}
+	buf := bufio.NewReader(connection)
+	statusLine, err := buf.ReadString('\n')
+	if err != nil {
+		return nil, errors.Join(
+			fmt.Errorf("failed to parse HTTP status line: %w", err),
+			connection.Close(),
+		)
+	}
 
-		buf := bufio.NewReader(connection)
-		statusLine, err := buf.ReadString('\n')
-		if err != nil {
-			channel <- Err[map[string]any](
-				fmt.Errorf("failed to parse HTTP status line: %w", err),
-				connection.Close(),
-			)
-			return
-		}
+	status, err := parseStatusLine(statusLine)
+	if err != nil {
+		return nil, errors.Join(err, connection.Close())
+	}
 
-		status, err := parseStatusLine(statusLine)
+	if strings.HasPrefix(status, "3") {
+		location, err := findLocation(buf, link)
 		if err != nil {
-			channel <- Err[map[string]any](err, connection.Close())
-			return
-		}
-
-		if strings.HasPrefix(status, "3") {
-			location, err := findLocation(buf, link)
-			if err != nil {
-				channel <- Err[map[string]any](err, connection.Close())
-				return
-			}
-
-			if maxRedirects == 0 {
-				channel <- Err[map[string]any](
-					errors.New("Received " + status + " but max redirects has already been reached"),
-					connection.Close(),
-				)
-				return
-			}
-
-			if err := connection.Close(); err != nil {
-				channel <- Err[map[string]any](err)
-				return
-			}
-			channel <- <-Get(location, maxRedirects - 1)
-			return
+			return nil, errors.Join(err, connection.Close())
 		}
 
-		if status != "200" && status != "201" && status != "202" && status != "203" {
-			channel <- Err[map[string]any](
-				errors.New("Received invalid status " + status),
+		if maxRedirects == 0 {
+			return nil, errors.Join(
+				errors.New("Received " + status + " but max redirects has already been reached"),
 				connection.Close(),
 			)
-			return
 		}
 
-		err = validateHeaders(buf)
-		if err != nil {
-			channel <- Err[map[string]any](
-				err,
-				connection.Close(),
-			)
-			return
+		if err := connection.Close(); err != nil {
+			return nil, err
 		}
+		return Get(location, accept, tolerated, maxRedirects - 1)
+	}
 
-		var dictionary map[string]any
-		err = json.NewDecoder(buf).Decode(&dictionary)
-		if err != nil {
-			channel <- Err[map[string]any](
-				fmt.Errorf("failed to parse JSON: %w", err),
-				connection.Close(),
-			)
-			return
-		}
+	if status != "200" && status != "201" && status != "202" && status != "203" {
+		return nil, errors.Join(
+			errors.New("received invalid status " + status),
+			connection.Close(),
+		)
+	}
 
-		if err := connection.Close(); err != nil {
-			channel <- Err[map[string]any](err)
-			return
-		}
+	err = validateHeaders(buf, tolerated)
+	if err != nil {
+		return nil, errors.Join(err, connection.Close())
+	}
 
-		channel <- Ok(dictionary)
-	}()
+	var dictionary map[string]any
+	err = json.NewDecoder(buf).Decode(&dictionary)
+	if err != nil {
+		return nil, errors.Join(
+			fmt.Errorf("failed to parse JSON: %w", err),
+			connection.Close(),
+		)
+	}
+
+	if err := connection.Close(); err != nil {
+		return nil, err
+	}
 
-	return channel
+	return dictionary, nil
 }
 
 func ParseMediaType(text string) (MediaType, error) {
@@ -220,7 +187,7 @@ func parseLocation(text string, baseLink *url.URL) (link *url.URL, isLocationLin
 	return baseLink.ResolveReference(reference), true, nil
 }
 
-func validateHeaders(buf *bufio.Reader) error {
+func validateHeaders(buf *bufio.Reader, tolerated []string) error {
 	contentTypeValidated := false
 	for {
 		line, err := buf.ReadString('\n')
@@ -240,7 +207,7 @@ func validateHeaders(buf *bufio.Reader) error {
 			continue
 		}
 
-		if slices.Contains(toleratedTypes, mediaType.Full) {
+		if slices.Contains(tolerated, mediaType.Full) {
 			contentTypeValidated = true
 		} else {
 			return errors.New("Response contains invalid content type " + mediaType.Full)

+ 23 - 9
jtp/jtp_test.go

@@ -6,7 +6,7 @@ import (
 	"net/url"
 	"encoding/json"
 	"os"
-	. "mimicry/preamble"
+	"sync"
 )
 
 func TestStatusLineNoInfo(t *testing.T) {
@@ -21,27 +21,41 @@ func TestStatusLineNoInfo(t *testing.T) {
 // TODO: put this behind an --online flag or figure out
 // how to nicely do offline tests
 func TestBasic(t *testing.T) {
+	accept := "application/activity+json"
+	tolerated := []string{"application/json"}
+
 	link, err := url.Parse("https://httpbin.org/redirect/20")
 	if err != nil {
 		panic(err)
 	}
 
-	dicts := AwaitAll(Get(link, 20), Get(link, 20))
-
-	if dicts[0].Err != nil {
-		panic(dicts[0].Err)
+	var dict1, dict2 map[string]any
+	var err1, err2 error
+	var wg sync.WaitGroup; wg.Add(2); {
+		go func() {
+			dict1, err1 = Get(link, accept, tolerated, 20)
+			wg.Done()
+		}()
+		go func() {
+			dict2, err2 = Get(link, accept, tolerated, 20)
+			wg.Done()
+		}()
+	}; wg.Wait()
+
+	if err1 != nil {
+		panic(err1)
 	}
 
-	if dicts[1].Err != nil {
-		panic(dicts[1].Err)
+	if err2 != nil {
+		panic(err2)
 	}
 
-	err = json.NewEncoder(os.Stdout).Encode(dicts[0].Ok)
+	err = json.NewEncoder(os.Stdout).Encode(dict1)
 	if err != nil {
 		panic(err)
 	}
 
-	err = json.NewEncoder(os.Stdout).Encode(dicts[1].Ok)
+	err = json.NewEncoder(os.Stdout).Encode(dict2)
 	if err != nil {
 		panic(err)
 	}