Explorar o código

300 redirects are taken into account when determining whether to refetch

Benton Edmondson hai 1 ano
pai
achega
bced6ea1bd
Modificáronse 3 ficheiros con 43 adicións e 48 borrados
  1. 5 6
      client/client.go
  2. 14 14
      jtp/jtp.go
  3. 24 28
      jtp/jtp_test.go

+ 5 - 6
client/client.go

@@ -21,7 +21,7 @@ func FetchUnknown(input any, source *url.URL) (object.Object, *url.URL, error) {
 		if err != nil {
 			return nil, nil, err
 		}
-		obj, err = FetchURL(url)
+		obj, source, err = FetchURL(url)
 		if err != nil { return nil, nil, err }
 	case map[string]any:
 		obj = object.Object(narrowed)
@@ -39,10 +39,10 @@ func FetchUnknown(input any, source *url.URL) (object.Object, *url.URL, error) {
 
 	if id != nil {
 		if source == nil {
-			obj, err = FetchURL(id)
+			obj, source, err = FetchURL(id)
 			if err != nil { return nil, nil, err }
 		} else if (source.Host != id.Host) || len(obj) <= 2 {
-			obj, err = FetchURL(id)
+			obj, source, err = FetchURL(id)
 			if err != nil { return nil, nil, err }
 		}
 	}
@@ -53,7 +53,7 @@ func FetchUnknown(input any, source *url.URL) (object.Object, *url.URL, error) {
 	return obj, id, err
 }
 
-func FetchURL(link *url.URL) (object.Object, error) {
+func FetchURL(link *url.URL) (object.Object, *url.URL, error) {
 	return jtp.Get(
 			link,
 			`application/activity+json,` +
@@ -94,14 +94,13 @@ func ResolveWebfinger(username string) (string, error) {
 		}).Encode(),
 	}
 
-	json, err := jtp.Get(link, "application/jrd+json", []string{
+	json, _, err := jtp.Get(link, "application/jrd+json", []string{
 		"application/jrd+json",
 		"application/json",
 	}, MAX_REDIRECTS)
 	if err != nil {
 		return "", err
 	}
-
 	response := object.Object(json)
 
 	jrdLinks, err := response.GetList("links")

+ 14 - 14
jtp/jtp.go

@@ -33,9 +33,9 @@ var locationRegexp = regexp.MustCompile(`^(?i:location):[ \t\r]*(.*?)[ \t\r]*\n$
 	maxRedirects
 		the maximum number of redirects to take
 */
-func Get(link *url.URL, accept string, tolerated []string, maxRedirects uint) (map[string]any, error) {
+func Get(link *url.URL, accept string, tolerated []string, maxRedirects uint) (map[string]any, *url.URL, error) {
 	if link.Scheme != "https" {
-		return nil, errors.New(link.Scheme + " is not supported in requests, only https")
+		return nil, nil, errors.New(link.Scheme + " is not supported in requests, only https")
 	}
 
 	port := link.Port()
@@ -48,7 +48,7 @@ func Get(link *url.URL, accept string, tolerated []string, maxRedirects uint) (m
 
 	connection, err := dialer.Dial("tcp", hostport)
 	if err != nil {
-		return nil, err
+		return nil, nil, err
 	}
 
 	_, err = connection.Write([]byte(
@@ -58,13 +58,13 @@ func Get(link *url.URL, accept string, tolerated []string, maxRedirects uint) (m
 		"\r\n",
 	))
 	if err != nil {
-		return nil, errors.Join(err, connection.Close())
+		return nil, nil, errors.Join(err, connection.Close())
 	}
 
 	buf := bufio.NewReader(connection)
 	statusLine, err := buf.ReadString('\n')
 	if err != nil {
-		return nil, errors.Join(
+		return nil, nil, errors.Join(
 			fmt.Errorf("failed to parse HTTP status line: %w", err),
 			connection.Close(),
 		)
@@ -72,30 +72,30 @@ func Get(link *url.URL, accept string, tolerated []string, maxRedirects uint) (m
 
 	status, err := parseStatusLine(statusLine)
 	if err != nil {
-		return nil, errors.Join(err, connection.Close())
+		return nil, nil, errors.Join(err, connection.Close())
 	}
 
 	if strings.HasPrefix(status, "3") {
 		location, err := findLocation(buf, link)
 		if err != nil {
-			return nil, errors.Join(err, connection.Close())
+			return nil, nil, errors.Join(err, connection.Close())
 		}
 
 		if maxRedirects == 0 {
-			return nil, errors.Join(
+			return nil, nil, errors.Join(
 				errors.New("Received " + status + " but max redirects has already been reached"),
 				connection.Close(),
 			)
 		}
 
 		if err := connection.Close(); err != nil {
-			return nil, err
+			return nil, nil, err
 		}
 		return Get(location, accept, tolerated, maxRedirects - 1)
 	}
 
 	if status != "200" && status != "201" && status != "202" && status != "203" {
-		return nil, errors.Join(
+		return nil, nil, errors.Join(
 			errors.New("received invalid status " + status),
 			connection.Close(),
 		)
@@ -103,23 +103,23 @@ func Get(link *url.URL, accept string, tolerated []string, maxRedirects uint) (m
 
 	err = validateHeaders(buf, tolerated)
 	if err != nil {
-		return nil, errors.Join(err, connection.Close())
+		return nil, nil, errors.Join(err, connection.Close())
 	}
 
 	var dictionary map[string]any
 	err = json.NewDecoder(buf).Decode(&dictionary)
 	if err != nil {
-		return nil, errors.Join(
+		return nil, nil, errors.Join(
 			fmt.Errorf("failed to parse JSON: %w", err),
 			connection.Close(),
 		)
 	}
 
 	if err := connection.Close(); err != nil {
-		return nil, err
+		return nil, nil, err
 	}
 
-	return dictionary, nil
+	return dictionary, link, nil
 }
 
 func parseStatusLine(text string) (string, error) {

+ 24 - 28
jtp/jtp_test.go

@@ -4,9 +4,6 @@ import (
 	"testing"
 	"mimicry/util"
 	"net/url"
-	"encoding/json"
-	"os"
-	"sync"
 )
 
 func TestStatusLineNoInfo(t *testing.T) {
@@ -20,43 +17,42 @@ func TestStatusLineNoInfo(t *testing.T) {
 
 // TODO: put this behind an --online flag or figure out
 // how to nicely do offline tests
-func TestBasic(t *testing.T) {
+func TestRedirect(t *testing.T) {
 	accept := "application/activity+json"
 	tolerated := []string{"application/json"}
 
-	link, err := url.Parse("https://httpbin.org/redirect/20")
+	link, err := url.Parse("https://httpbin.org/redirect/5")
 	if err != nil {
-		panic(err)
+		t.Fatalf("invalid url literal: %s", err)
 	}
 
-	var dict1, dict2 map[string]any
-	var err1, err2 error
-	var wg sync.WaitGroup; wg.Add(2); {
-		go func() {
-			dict1, err1 = Get(link, accept, tolerated, 20)
-			wg.Done()
-		}()
-		go func() {
-			dict2, err2 = Get(link, accept, tolerated, 20)
-			wg.Done()
-		}()
-	}; wg.Wait()
-
-	if err1 != nil {
-		panic(err1)
+	_, actualLink, err := Get(link, accept, tolerated, 5)
+
+	if err != nil {
+		t.Fatalf("failed to preform request: %s", err)
 	}
 
-	if err2 != nil {
-		panic(err2)
+	if link.String() == actualLink.String() {
+		t.Fatalf("failed to return the underlying url redirected to by %s", link.String())
 	}
+}
 
-	err = json.NewEncoder(os.Stdout).Encode(dict1)
+func TestBasic(t *testing.T) {
+	accept := "application/activity+json"
+	tolerated := []string{"application/json"}
+
+	link, err := url.Parse("https://httpbin.org/get")
 	if err != nil {
-		panic(err)
+		t.Fatalf("invalid url literal: %s", err)
 	}
 
-	err = json.NewEncoder(os.Stdout).Encode(dict2)
+	_, actualLink, err := Get(link, accept, tolerated, 20)
+
 	if err != nil {
-		panic(err)
+		t.Fatalf("failed to preform request: %s", err)
 	}
-}
+
+	if link.String() != actualLink.String() {
+		t.Fatalf("underlying url %s should match request url %s", actualLink.String(), link.String())
+	}
+}