Browse Source

add various checks that posts, comments, and outbox elements are not forged

Benton Edmondson 1 year ago
parent
commit
20adced6ea
6 changed files with 112 additions and 18 deletions
  1. 3 3
      client/client.go
  2. 8 0
      pub/activity.go
  3. 20 1
      pub/actor.go
  4. 18 5
      pub/collection.go
  5. 13 4
      pub/common.go
  6. 50 5
      pub/post.go

+ 3 - 3
client/client.go

@@ -43,7 +43,7 @@ func FetchUnknown(input any, source *url.URL) (object.Object, *url.URL, error) {
 		return nil, nil, err
 	}
 	/* Refetch if necessary */
-	if id != nil && (source == nil || source.Host != id.Host || len(obj) <= 2) {
+	if id != nil && (source == nil || source.String() != id.String() || len(obj) <= 2) {
 		obj, source, err = FetchURL(id)
 		if err != nil {
 			return nil, nil, err
@@ -55,8 +55,8 @@ func FetchUnknown(input any, source *url.URL) (object.Object, *url.URL, error) {
 		} else if err != nil {
 			return nil, nil, err
 		}
-		if id != nil && source.Host != id.Host {
-			return nil, nil, errors.New("received response with forged ID")
+		if id != nil && source.String() != id.String() {
+			return nil, nil, errors.New("received response with forged identifier")
 		}
 	}
 

+ 8 - 0
pub/activity.go

@@ -130,6 +130,14 @@ func (a *Activity) Actor() Tangible {
 	return a.actor
 }
 
+func (a *Activity) ActorIdentifier() *url.URL {
+	if a.actorErr != nil {
+		return nil
+	}
+
+	return a.actor.Identifier()
+}
+
 func (a *Activity) Target() Tangible {
 	return a.target
 }

+ 20 - 1
pub/actor.go

@@ -69,7 +69,22 @@ func NewActorFromObject(o object.Object, id *url.URL) (*Actor, error) {
 	a.pfp, a.pfpErr = getBestLink(o, "icon", "image")
 	a.banner, a.bannerErr = getBestLink(o, "image", "image")
 
-	a.posts, a.postsErr = getCollection(o, "outbox", a.id)
+	a.posts, a.postsErr = getCollection(o, "outbox", a.id, func(input any, source *url.URL) Tangible {
+		activity, err := NewActivity(input, source)
+		if err != nil {
+			return NewFailure(err)
+		}
+
+		if id == nil {
+			return NewFailure(errors.New("activity was performed by a different actor (this actor has no identifier)"))
+		}
+
+		if activity.ActorIdentifier() == nil || activity.ActorIdentifier().String() != id.String() {
+			return NewFailure(errors.New("activity was performed by a different actor"))
+		}
+
+		return activity
+	})
 	return a, nil
 }
 
@@ -199,6 +214,10 @@ func (a *Actor) Timestamp() time.Time {
 	}
 }
 
+func (a *Actor) Identifier() *url.URL {
+	return a.id
+}
+
 func (a *Actor) Banner() (string, *mime.MediaType, bool) {
 	if a.bannerErr != nil {
 		return "", nil, false

+ 18 - 5
pub/collection.go

@@ -38,17 +38,19 @@ type Collection struct {
 
 	size    uint64
 	sizeErr error
+
+	construct func(any, *url.URL,) Tangible
 }
 
-func NewCollection(input any, source *url.URL) (*Collection, error) {
+func NewCollection(input any, source *url.URL, construct func(any, *url.URL) Tangible) (*Collection, error) {
 	o, id, err := client.FetchUnknown(input, source)
 	if err != nil {
 		return nil, err
 	}
-	return NewCollectionFromObject(o, id)
+	return NewCollectionFromObject(o, id, construct)
 }
 
-func NewCollectionFromObject(o object.Object, id *url.URL) (*Collection, error) {
+func NewCollectionFromObject(o object.Object, id *url.URL, construct func(any, *url.URL) Tangible) (*Collection, error) {
 	c := &Collection{}
 	c.id = id
 	var err error
@@ -62,6 +64,8 @@ func NewCollectionFromObject(o object.Object, id *url.URL) (*Collection, error)
 		return nil, fmt.Errorf("%w: %s is not a Collection", ErrWrongType, c.kind)
 	}
 
+	c.construct = construct
+
 	if c.kind == "Collection" || c.kind == "CollectionPage" {
 		c.elements, c.elementsErr = o.GetList("items")
 	} else {
@@ -137,7 +141,7 @@ func (c *Collection) harvestWithEmptyCount(amount uint, startingPoint uint, empt
 		i := i
 		wg.Add(1)
 		go func() {
-			fromThisPage[i] = NewTangible(c.elements[i+startingPoint], c.id)
+			fromThisPage[i] = c.construct(c.elements[i+startingPoint], c.id)
 			wg.Done()
 		}()
 	}
@@ -150,7 +154,7 @@ func (c *Collection) harvestWithEmptyCount(amount uint, startingPoint uint, empt
 			fromLaterPages, nextCollection, nextStartingPoint = []Tangible{}, nil, 0
 		} else if c.nextErr != nil {
 			fromLaterPages, nextCollection, nextStartingPoint = []Tangible{NewFailure(c.nextErr)}, nil, 0
-		} else if next, err := NewCollection(c.next, c.id); err != nil {
+		} else if next, err := NewCollection(c.next, c.id, c.construct); err != nil {
 			fromLaterPages, nextCollection, nextStartingPoint = []Tangible{NewFailure(err)}, nil, 0
 		} else {
 			fromLaterPages, nextCollection, nextStartingPoint = next.harvestWithEmptyCount(amount-amountFromThisPage, 0, emptyCount)
@@ -161,5 +165,14 @@ func (c *Collection) harvestWithEmptyCount(amount uint, startingPoint uint, empt
 
 	wg.Wait()
 
+	n := 0
+	for _, element := range fromThisPage {
+		if element != nil {
+			fromThisPage[n] = element
+			n += 1
+		}
+	}
+	fromThisPage = fromThisPage[:n]
+
 	return append(fromThisPage, fromLaterPages...), nextCollection, nextStartingPoint
 }

+ 13 - 4
pub/common.go

@@ -85,13 +85,13 @@ func getPostOrActor(o object.Object, key string, source *url.URL) Tangible {
 	return fetched
 }
 
-func getCollection(o object.Object, key string, source *url.URL) (*Collection, error) {
+func getCollection(o object.Object, key string, source *url.URL, construct func(any, *url.URL) Tangible) (*Collection, error) {
 	reference, err := o.GetAny(key)
 	if err != nil {
 		return nil, err
 	}
 
-	fetched, err := NewCollection(reference, source)
+	fetched, err := NewCollection(reference, source, construct)
 	if err != nil {
 		return nil, err
 	}
@@ -111,12 +111,21 @@ func getActor(o object.Object, key string, source *url.URL) (*Actor, error) {
 	return fetched, nil
 }
 
+func getAndFetchUnkown(o object.Object, key string, source *url.URL) (object.Object, *url.URL, error) {
+	reference, err := o.GetAny(key)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	return client.FetchUnknown(reference, source)
+}
+
 func NewTangible(input any, source *url.URL) Tangible {
 	fetched := New(input, source)
 	if tangible, ok := fetched.(Tangible); ok {
 		return tangible
 	}
-	return NewFailure(errors.New("item is non-Tangible"))
+	return NewFailure(errors.New("item is a collection"))
 }
 
 func New(input any, source *url.URL) any {
@@ -148,7 +157,7 @@ func New(input any, source *url.URL) any {
 		return NewFailure(err)
 	}
 
-	result, err = NewCollectionFromObject(o, id)
+	result, err = NewCollectionFromObject(o, id, NewTangible)
 	if err == nil {
 		return result
 	} else if !errors.Is(err, ErrWrongType) {

+ 50 - 5
pub/post.go

@@ -30,7 +30,8 @@ type Post struct {
 	createdErr error
 	edited     time.Time
 	editedErr  error
-	parent     any
+	parentObject     object.Object
+	parentIdentifier *url.URL
 	parentErr  error
 
 	// just as body dies completely if members die,
@@ -60,6 +61,10 @@ func NewPostFromObject(o object.Object, id *url.URL) (*Post, error) {
 		return nil, err
 	}
 
+	if p.kind == "Tombstone" {
+		return nil, errors.New("post was deleted")
+	}
+
 	if !slices.Contains([]string{
 		"Article", "Audio", "Document", "Image", "Note", "Page", "Video",
 	}, p.kind) {
@@ -70,7 +75,7 @@ func NewPostFromObject(o object.Object, id *url.URL) (*Post, error) {
 	p.body, p.bodyLinks, p.bodyErr = o.GetMarkup("content", "mediaType")
 	p.created, p.createdErr = o.GetTime("published")
 	p.edited, p.editedErr = o.GetTime("updated")
-	p.parent, p.parentErr = o.GetAny("inReplyTo")
+	p.parentObject, p.parentIdentifier, p.parentErr = getAndFetchUnkown(o, "inReplyTo", p.id)
 
 	if p.kind == "Audio" || p.kind == "Video" || p.kind == "Image" {
 		p.media, p.mediaErr = getBestLinkShorthand(o, "url", strings.ToLower(p.kind))
@@ -83,14 +88,47 @@ func NewPostFromObject(o object.Object, id *url.URL) (*Post, error) {
 	go func() { p.creators = getActors(o, "attributedTo", p.id); wg.Done() }()
 	go func() { p.recipients = getActors(o, "audience", p.id); wg.Done() }()
 	go func() { p.attachments, p.attachmentsErr = getLinks(o, "attachment"); wg.Done() }()
+
+	constructComment := func(input any, source *url.URL) Tangible {
+		comment, err := NewPost(input, source)
+		if err != nil {
+			return NewFailure(err)
+		}
+
+		if id == nil {
+			return NewFailure(errors.New("comment does not reference this parent (parent lacks an identifier)"))
+		}
+
+		if comment.ParentIdentifier() == nil || comment.ParentIdentifier().String() != id.String() {
+			return NewFailure(errors.New("comment does not reference this parent"))
+		}
+
+		return comment
+	}
+
 	go func() {
-		p.comments, p.commentsErr = getCollection(o, "replies", p.id)
+		p.comments, p.commentsErr = getCollection(o, "replies", p.id, constructComment)
 		if errors.Is(p.commentsErr, object.ErrKeyNotPresent) {
-			p.comments, p.commentsErr = getCollection(o, "comments", p.id)
+			p.comments, p.commentsErr = getCollection(o, "comments", p.id, constructComment)
 		}
 		wg.Done()
 	}()
 	wg.Wait()
+
+	/* Ensure that creators come from the same host as the post itself */
+	for _, creator := range p.creators {
+		if asActor, isActor := creator.(*Actor); isActor {
+			if asActor.Identifier() == nil && id == nil {
+				continue
+			}
+
+			if (asActor.Identifier() == nil || id == nil) || asActor.Identifier().Host != id.Host {
+				return nil, errors.New("post contains forged creators")
+			}
+		}
+		/* Creators that are not *Actor are necessarily Failure types, so they don't need to be checked */
+	}
+
 	return p, nil
 }
 
@@ -115,7 +153,7 @@ func (p *Post) Parents(quantity uint) ([]Tangible, Tangible) {
 	if p.parentErr != nil {
 		return []Tangible{NewFailure(p.parentErr)}, nil
 	}
-	fetchedParent, fetchedParentErr := NewPost(p.parent, p.id)
+	fetchedParent, fetchedParentErr := NewPostFromObject(p.parentObject, p.parentIdentifier)
 	if fetchedParentErr != nil {
 		return []Tangible{NewFailure(fetchedParentErr)}, nil
 	}
@@ -126,6 +164,13 @@ func (p *Post) Parents(quantity uint) ([]Tangible, Tangible) {
 	return append([]Tangible{fetchedParent}, fetchedParentParents...), fetchedParentFrontier
 }
 
+func (p *Post) ParentIdentifier() *url.URL {
+	if p.parentErr != nil {
+		return nil
+	}
+	return p.parentIdentifier
+}
+
 func (p *Post) header(width int) string {
 	output := ""