Переглянути джерело

jtp: made async instead of sequential

Benton Edmondson 2 роки тому
батько
коміт
297f20598c
2 змінених файлів з 96 додано та 57 видалено
  1. 83 55
      jtp/jtp.go
  2. 13 2
      jtp/jtp_test.go

+ 83 - 55
jtp/jtp.go

@@ -11,6 +11,7 @@ import (
 	"fmt"
 	"strings"
 	"encoding/json"
+	. "mimicry/preamble"
 )
 
 // TODO: parseMediaType should probably return an error if the mediaType is invalid
@@ -53,77 +54,104 @@ var toleratedTypes = []string{
 	maxRedirects
 		the maximum number of redirects to take
 */
-// TODO: the number of redirects must be limited
-func Get(link *url.URL, maxRedirects uint) (map[string]any, error) {
+func Get(link *url.URL, maxRedirects uint) <-chan *Result[map[string]any] {
 
-	if link.Scheme != "https" {
-		return nil, errors.New(link.Scheme + " is not supported in requests, only https")
-	}
+	channel := make(chan *Result[map[string]any], 1)
 
-	port := link.Port()
-	if port == "" {
-		port = "443"
-	}
+	go func() {
+		if link.Scheme != "https" {
+			channel <- Err[map[string]any](errors.New(link.Scheme + " is not supported in requests, only https"))
+			return
+		}
 
-	hostport := net.JoinHostPort(link.Hostname(), port)
+		port := link.Port()
+		if port == "" {
+			port = "443"
+		}
 
-	connection, err := dialer.Dial("tcp", hostport)
-	if err != nil {
-		return nil, err
-	}
-	defer connection.Close()
-
-	_, err = connection.Write([]byte(
-		"GET " + link.RequestURI() + " HTTP/1.0\r\n" +
-		"Host: " + link.Host + "\r\n" +
-		"Accept: " + acceptHeader + "\r\n" +
-		"Accept-Encoding: identity\r\n" +
-		"\r\n",
-	))
-	if err != nil {
-		return nil, err
-	}
+		hostport := net.JoinHostPort(link.Hostname(), port)
 
-	buf := bufio.NewReader(connection)
-	statusLine, err := buf.ReadString('\n')
-	if err != nil {
-		return nil, fmt.Errorf("Encountered error while reading status line of HTTP response: %w", err)
-	}
+		connection, err := dialer.Dial("tcp", hostport)
+		if err != nil {
+			channel <- Err[map[string]any](err)
+			return
+		}
 
-	status, err := parseStatusLine(statusLine)
-	if err != nil {
-		return nil, err
-	}
+		_, err = connection.Write([]byte(
+			"GET " + link.RequestURI() + " HTTP/1.0\r\n" +
+			"Host: " + link.Host + "\r\n" +
+			"Accept: " + acceptHeader + "\r\n" +
+			"Accept-Encoding: identity\r\n" +
+			"\r\n",
+		))
+		if err != nil {
+			channel <- Err[map[string]any](err, connection.Close())
+			return
+		}
 
-	if strings.HasPrefix(status, "3") {
-		location, err := findLocation(buf, link)
+		buf := bufio.NewReader(connection)
+		statusLine, err := buf.ReadString('\n')
 		if err != nil {
-			return nil, err
+			channel <- Err[map[string]any](
+				fmt.Errorf("failed to parse HTTP status line: %w", err),
+				connection.Close(),
+			)
+			return
 		}
 
-		if maxRedirects == 0 {
-			return nil, errors.New("Received " + status + " but max redirects has already been reached")
+		status, err := parseStatusLine(statusLine)
+		if err != nil {
+			channel <- Err[map[string]any](err, connection.Close())
+			return
 		}
 
-		return Get(location, maxRedirects - 1)
-	}
+		if strings.HasPrefix(status, "3") {
+			location, err := findLocation(buf, link)
+			if err != nil {
+				channel <- Err[map[string]any](err, connection.Close())
+				return
+			}
+
+			if maxRedirects == 0 {
+				channel <- Err[map[string]any](
+					errors.New("Received " + status + " but max redirects has already been reached"),
+					connection.Close(),
+				)
+				return
+			}
+
+			channel <- <-Get(location, maxRedirects - 1)
+			return
+		}
 
-	if status != "200" && status != "201" && status != "202" && status != "203" {
-		return nil, errors.New("Received invalid status " + status)
-	}
+		if status != "200" && status != "201" && status != "202" && status != "203" {
+			channel <- Err[map[string]any](errors.New("Received invalid status " + status))
+			return
+		}
 
-	err = validateHeaders(buf)
-	if err != nil {
-		return nil, err
-	}
+		err = validateHeaders(buf)
+		if err != nil {
+			channel <- Err[map[string]any](err)
+			return
+		}
 
-	var dictionary map[string]any
-	err = json.NewDecoder(buf).Decode(&dictionary)
-	if err != nil {
-		return nil, err
-	}
+		var dictionary map[string]any
+		err = json.NewDecoder(buf).Decode(&dictionary)
+		if err != nil {
+			channel <- Err[map[string]any](fmt.Errorf("failed to parse JSON: %w", err))
+			return
+		}
+
+		err = connection.Close()
+		if err != nil {
+			channel <- Err[map[string]any](err)
+			return
+		}
+
+		channel <- Ok(dictionary)
+	}()
 
-	return dictionary, nil
+	return channel
 }
 
 func ParseMediaType(text string) (MediaType, error) {

+ 13 - 2
jtp/jtp_test.go

@@ -6,6 +6,7 @@ import (
 	"net/url"
 	"encoding/json"
 	"os"
+	. "mimicry/preamble"
 )
 
 func TestStatusLineNoInfo(t *testing.T) {
@@ -25,12 +26,22 @@ func TestBasic(t *testing.T) {
 		panic(err)
 	}
 
-	dict, err := Get(link, 20)
+	dicts := AwaitAll(Get(link, 20), Get(link, 20))
+
+	if dicts[0].Err != nil {
+		panic(dicts[0].Err)
+	}
+
+	if dicts[1].Err != nil {
+		panic(dicts[1].Err)
+	}
+
+	err = json.NewEncoder(os.Stdout).Encode(dicts[0].Ok)
 	if err != nil {
 		panic(err)
 	}
 
-	err = json.NewEncoder(os.Stdout).Encode(dict)
+	err = json.NewEncoder(os.Stdout).Encode(dicts[1].Ok)
 	if err != nil {
 		panic(err)
 	}