Browse Source

jtp: make ParseMediaType return a pointer

Benton Edmondson 2 years ago
parent
commit
aff54bd700
1 changed files with 34 additions and 29 deletions
  1. 34 29
      jtp/jtp.go

+ 34 - 29
jtp/jtp.go

@@ -2,7 +2,6 @@ package jtp
 
 import (
 	"regexp"
-	"golang.org/x/exp/slices"
 	"errors"
 	"crypto/tls"
 	"net"
@@ -13,16 +12,6 @@ import (
 	"encoding/json"
 )
 
-// TODO: parseMediaType should probably return an error if the mediaType is invalid
-// or at least do something that will be easier to debug
-
-type MediaType struct {
-	Supertype string
-	Subtype string
-	/* Full omits the parameters */
-	Full string
-}
-
 var dialer = &tls.Dialer{
 	NetDialer: &net.Dialer{},
 }
@@ -133,20 +122,6 @@ func Get(link *url.URL, accept string, tolerated []string, maxRedirects uint) (m
 	return dictionary, nil
 }
 
-func ParseMediaType(text string) (MediaType, error) {
-	matches := mediaTypeRegexp.FindStringSubmatch(text)
-
-	if len(matches) != 4 {
-		return MediaType{}, errors.New(text + " is not a valid media type")
-	}
-
-	return MediaType{
-		Supertype: matches[2],
-		Subtype: matches[3],
-		Full: matches[1],
-	}, nil
-}
-
 func parseStatusLine(text string) (string, error) {
 	matches := statusLineRegexp.FindStringSubmatch(text)
 
@@ -157,16 +132,16 @@ func parseStatusLine(text string) (string, error) {
 	return matches[1], nil
 }
 
-func parseContentType(text string) (MediaType, bool, error) {
+func parseContentType(text string) (*MediaType, bool, error) {
 	matches := contentTypeRegexp.FindStringSubmatch(text)
 
 	if len(matches) != 2 {
-		return MediaType{}, false, nil
+		return nil, false, nil
 	}
 
 	mediaType, err := ParseMediaType(matches[1])
 	if err != nil {
-		return MediaType{}, true, err
+		return nil, true, err
 	}
 
 	return mediaType, true, nil
@@ -207,7 +182,7 @@ func validateHeaders(buf *bufio.Reader, tolerated []string) error {
 			continue
 		}
 
-		if slices.Contains(tolerated, mediaType.Full) {
+		if mediaType.Matches(tolerated) {
 			contentTypeValidated = true
 		} else {
 			return errors.New("Response contains invalid content type " + mediaType.Full)
@@ -244,3 +219,33 @@ func findLocation(buf *bufio.Reader, baseLink *url.URL) (*url.URL, error) {
 	}
 	return nil, errors.New("Location is not present in headers")
 }
+
+type MediaType struct {
+	Supertype string
+	Subtype string
+	/* Full omits the parameters */
+	Full string
+}
+
+func ParseMediaType(text string) (*MediaType, error) {
+	matches := mediaTypeRegexp.FindStringSubmatch(text)
+
+	if len(matches) != 4 {
+		return nil, errors.New(text + " is not a valid media type")
+	}
+
+	return &MediaType{
+		Supertype: matches[2],
+		Subtype: matches[3],
+		Full: matches[1],
+	}, nil
+}
+
+func (m *MediaType) Matches(mediaTypes []string) bool {
+	for _, mediaType := range mediaTypes {
+		if m.Full == mediaType {
+			return true
+		}
+	}
+	return false
+}