feat: Allow circular references when the properties are not marked as required

This commit is contained in:
Benjamin Nolan (TwoWholeWorms)
2023-01-05 20:43:45 +01:00
committed by Dave Shanley
parent e8a954d5ae
commit ee504c543a
4 changed files with 196 additions and 43 deletions

View File

@@ -9,15 +9,21 @@ type CircularReferenceResult struct {
LoopIndex int LoopIndex int
LoopPoint *Reference LoopPoint *Reference
IsPolymorphicResult bool // if this result comes from a polymorphic loop. IsPolymorphicResult bool // if this result comes from a polymorphic loop.
IsInfiniteLoop bool // if all the definitions in the reference loop are marked as required, this is an infinite circular reference, thus is not allowed.
} }
func (c *CircularReferenceResult) GenerateJourneyPath() string { func (c *CircularReferenceResult) GenerateJourneyPath() string {
buf := strings.Builder{} buf := strings.Builder{}
for i, ref := range c.Journey { for i, ref := range c.Journey {
buf.WriteString(ref.Name) if i > 0 {
if i+1 < len(c.Journey) {
buf.WriteString(" -> ") buf.WriteString(" -> ")
} }
buf.WriteString(ref.Name)
// buf.WriteString(" (")
// buf.WriteString(ref.Definition)
// buf.WriteString(")")
} }
return buf.String() return buf.String()
} }

View File

@@ -35,16 +35,17 @@ const (
// Reference is a wrapper around *yaml.Node results to make things more manageable when performing // Reference is a wrapper around *yaml.Node results to make things more manageable when performing
// algorithms on data models. the *yaml.Node def is just a bit too low level for tracking state. // algorithms on data models. the *yaml.Node def is just a bit too low level for tracking state.
type Reference struct { type Reference struct {
Definition string Definition string
Name string Name string
Node *yaml.Node Node *yaml.Node
ParentNode *yaml.Node ParentNode *yaml.Node
Resolved bool Resolved bool
Circular bool Circular bool
Seen bool Seen bool
IsRemote bool IsRemote bool
RemoteLocation string RemoteLocation string
Path string // this won't always be available. Path string // this won't always be available.
RequiredRefProperties map[string][]string // definition names (eg, #/definitions/One) to a list of required properties on this definition which reference that definition
} }
// ReferenceMapped is a helper struct for mapped references put into sequence (we lose the key) // ReferenceMapped is a helper struct for mapped references put into sequence (we lose the key)
@@ -1614,10 +1615,12 @@ func (index *SpecIndex) FindComponent(componentId string, parent *yaml.Node) *Re
return nil return nil
} }
// FIXME: This is a potential security hole, and needs to be made optional (see log4j fiasco)
remoteLookup := func(id string) (*yaml.Node, *yaml.Node, error) { remoteLookup := func(id string) (*yaml.Node, *yaml.Node, error) {
return index.lookupRemoteReference(id) return index.lookupRemoteReference(id)
} }
// FIXME: As above
fileLookup := func(id string) (*yaml.Node, *yaml.Node, error) { fileLookup := func(id string) (*yaml.Node, *yaml.Node, error) {
return index.lookupFileReference(id) return index.lookupFileReference(id)
} }
@@ -1698,18 +1701,93 @@ func (index *SpecIndex) extractDefinitionsAndSchemas(schemasNode *yaml.Node, pat
name = schema.Value name = schema.Value
continue continue
} }
def := fmt.Sprintf("%s%s", pathPrefix, name) def := fmt.Sprintf("%s%s", pathPrefix, name)
ref := &Reference{ ref := &Reference{
Definition: def, Definition: def,
Name: name, Name: name,
Node: schema, Node: schema,
Path: fmt.Sprintf("$.components.schemas.%s", name), Path: fmt.Sprintf("$.components.schemas.%s", name),
ParentNode: schemasNode, ParentNode: schemasNode,
RequiredRefProperties: index.extractDefinitionRequiredRefProperties(schemasNode, map[string][]string{}),
} }
index.allSchemas[def] = ref index.allSchemas[def] = ref
} }
} }
// extractDefinitionRequiredRefProperties goes through the direct properties of a schema and extracts the map of required definitions from within it
func (index *SpecIndex) extractDefinitionRequiredRefProperties(schemaNode *yaml.Node, reqRefProps map[string][]string) map[string][]string {
if schemaNode == nil {
return reqRefProps
}
_, requiredSeqNode := utils.FindKeyNode("required", schemaNode.Content)
if requiredSeqNode == nil {
return reqRefProps
}
_, propertiesMapNode := utils.FindKeyNode("properties", schemaNode.Content)
if propertiesMapNode == nil {
// TODO: Log a warning on the resolver, because if you have required properties, but no actual properties, something is wrong
return reqRefProps
}
name := ""
for i, param := range propertiesMapNode.Content {
if i%2 == 0 {
name = param.Value
continue
}
_, paramPropertiesMapNode := utils.FindKeyNode("properties", param.Content)
if paramPropertiesMapNode != nil {
reqRefProps = index.extractDefinitionRequiredRefProperties(param, reqRefProps)
}
for _, key := range []string{"allOf", "oneOf", "anyOf"} {
_, ofNode := utils.FindKeyNode(key, param.Content)
if ofNode != nil {
for _, ofNodeItem := range ofNode.Content {
reqRefProps = index.extractRequiredReferenceProperties(ofNodeItem, name, reqRefProps)
}
}
}
}
for _, requiredPropertyNode := range requiredSeqNode.Content {
_, requiredPropDefNode := utils.FindKeyNode(requiredPropertyNode.Value, propertiesMapNode.Content)
if requiredPropDefNode == nil {
continue
}
reqRefProps = index.extractRequiredReferenceProperties(requiredPropDefNode, requiredPropertyNode.Value, reqRefProps)
}
return reqRefProps
}
// extractRequiredReferenceProperties returns a map of definition names to the property or properties which reference it within a node
func (index *SpecIndex) extractRequiredReferenceProperties(requiredPropDefNode *yaml.Node, propName string, reqRefProps map[string][]string) map[string][]string {
isRef, _, defPath := utils.IsNodeRefValue(requiredPropDefNode)
if !isRef {
_, defItems := utils.FindKeyNode("items", requiredPropDefNode.Content)
if defItems != nil {
isRef, _, defPath = utils.IsNodeRefValue(defItems)
}
}
if /* still */ !isRef {
return reqRefProps
}
if _, ok := reqRefProps[defPath]; !ok {
reqRefProps[defPath] = []string{}
}
reqRefProps[defPath] = append(reqRefProps[defPath], propName)
return reqRefProps
}
func (index *SpecIndex) extractComponentParameters(paramsNode *yaml.Node, pathPrefix string) { func (index *SpecIndex) extractComponentParameters(paramsNode *yaml.Node, pathPrefix string) {
var name string var name string
for i, param := range paramsNode.Content { for i, param := range paramsNode.Content {
@@ -1899,11 +1977,13 @@ func (index *SpecIndex) FindComponentInRoot(componentId string) *Reference {
res, _ := path.Find(index.root) res, _ := path.Find(index.root)
if len(res) == 1 { if len(res) == 1 {
ref := &Reference{ ref := &Reference{
Definition: componentId, Definition: componentId,
Name: name, Name: name,
Node: res[0], Node: res[0],
Path: friendlySearch, Path: friendlySearch,
RequiredRefProperties: index.extractDefinitionRequiredRefProperties(res[0], map[string][]string{}),
} }
return ref return ref
} }
} }

View File

@@ -5,6 +5,7 @@ package resolver
import ( import (
"fmt" "fmt"
"github.com/pb33f/libopenapi/index" "github.com/pb33f/libopenapi/index"
"github.com/pb33f/libopenapi/utils" "github.com/pb33f/libopenapi/utils"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
@@ -87,7 +88,6 @@ func (resolver *Resolver) GetNonPolymorphicCircularErrors() []*index.CircularRef
// re-organize the node tree. Make sure you have copied your original tree before running this (if you want to preserve // re-organize the node tree. Make sure you have copied your original tree before running this (if you want to preserve
// original data) // original data)
func (resolver *Resolver) Resolve() []*ResolvingError { func (resolver *Resolver) Resolve() []*ResolvingError {
mapped := resolver.specIndex.GetMappedReferencesSequenced() mapped := resolver.specIndex.GetMappedReferencesSequenced()
mappedIndex := resolver.specIndex.GetMappedReferences() mappedIndex := resolver.specIndex.GetMappedReferences()
@@ -98,7 +98,6 @@ func (resolver *Resolver) Resolve() []*ResolvingError {
} }
schemas := resolver.specIndex.GetAllSchemas() schemas := resolver.specIndex.GetAllSchemas()
for s, schemaRef := range schemas { for s, schemaRef := range schemas {
if mappedIndex[s] == nil { if mappedIndex[s] == nil {
seenReferences := make(map[string]bool) seenReferences := make(map[string]bool)
@@ -118,8 +117,13 @@ func (resolver *Resolver) Resolve() []*ResolvingError {
} }
for _, circRef := range resolver.circularReferences { for _, circRef := range resolver.circularReferences {
// If the circular reference is not required, we can ignore it, as it's a terminable loop rather than an infinite one
if !circRef.IsInfiniteLoop {
continue
}
resolver.resolvingErrors = append(resolver.resolvingErrors, &ResolvingError{ resolver.resolvingErrors = append(resolver.resolvingErrors, &ResolvingError{
ErrorRef: fmt.Errorf("Circular reference detected: %s", circRef.Start.Name), ErrorRef: fmt.Errorf("Infinite circular reference detected: %s", circRef.Start.Name),
Node: circRef.LoopPoint.Node, Node: circRef.LoopPoint.Node,
Path: circRef.GenerateJourneyPath(), Path: circRef.GenerateJourneyPath(),
}) })
@@ -130,7 +134,6 @@ func (resolver *Resolver) Resolve() []*ResolvingError {
// CheckForCircularReferences Check for circular references, without resolving, a non-destructive run. // CheckForCircularReferences Check for circular references, without resolving, a non-destructive run.
func (resolver *Resolver) CheckForCircularReferences() []*ResolvingError { func (resolver *Resolver) CheckForCircularReferences() []*ResolvingError {
mapped := resolver.specIndex.GetMappedReferencesSequenced() mapped := resolver.specIndex.GetMappedReferencesSequenced()
mappedIndex := resolver.specIndex.GetMappedReferences() mappedIndex := resolver.specIndex.GetMappedReferences()
for _, ref := range mapped { for _, ref := range mapped {
@@ -138,6 +141,7 @@ func (resolver *Resolver) CheckForCircularReferences() []*ResolvingError {
var journey []*index.Reference var journey []*index.Reference
resolver.VisitReference(ref.Reference, seenReferences, journey, false) resolver.VisitReference(ref.Reference, seenReferences, journey, false)
} }
schemas := resolver.specIndex.GetAllSchemas() schemas := resolver.specIndex.GetAllSchemas()
for s, schemaRef := range schemas { for s, schemaRef := range schemas {
if mappedIndex[s] == nil { if mappedIndex[s] == nil {
@@ -146,9 +150,15 @@ func (resolver *Resolver) CheckForCircularReferences() []*ResolvingError {
resolver.VisitReference(schemaRef, seenReferences, journey, false) resolver.VisitReference(schemaRef, seenReferences, journey, false)
} }
} }
for _, circRef := range resolver.circularReferences { for _, circRef := range resolver.circularReferences {
// If the circular reference is not required, we can ignore it, as it's a terminable loop rather than an infinite one
if !circRef.IsInfiniteLoop {
continue
}
resolver.resolvingErrors = append(resolver.resolvingErrors, &ResolvingError{ resolver.resolvingErrors = append(resolver.resolvingErrors, &ResolvingError{
ErrorRef: fmt.Errorf("Circular reference detected: %s", circRef.Start.Name), ErrorRef: fmt.Errorf("Infinite circular reference detected: %s", circRef.Start.Name),
Node: circRef.LoopPoint.Node, Node: circRef.LoopPoint.Node,
Path: circRef.GenerateJourneyPath(), Path: circRef.GenerateJourneyPath(),
CircularReference: circRef, CircularReference: circRef,
@@ -161,7 +171,6 @@ func (resolver *Resolver) CheckForCircularReferences() []*ResolvingError {
// VisitReference will visit a reference as part of a journey and will return resolved nodes. // VisitReference will visit a reference as part of a journey and will return resolved nodes.
func (resolver *Resolver) VisitReference(ref *index.Reference, seen map[string]bool, journey []*index.Reference, resolve bool) []*yaml.Node { func (resolver *Resolver) VisitReference(ref *index.Reference, seen map[string]bool, journey []*index.Reference, resolve bool) []*yaml.Node {
if ref.Resolved || ref.Seen { if ref.Resolved || ref.Seen {
return ref.Node.Content return ref.Node.Content
} }
@@ -173,34 +182,32 @@ func (resolver *Resolver) VisitReference(ref *index.Reference, seen map[string]b
seen[ref.Definition] = true seen[ref.Definition] = true
for _, r := range relatives { for _, r := range relatives {
// check if we have seen this on the journey before, if so! it's circular // check if we have seen this on the journey before, if so! it's circular
skip := false skip := false
for i, j := range journey { for i, j := range journey {
if j.Definition == r.Definition { if j.Definition == r.Definition {
foundDup := resolver.specIndex.GetMappedReferences()[r.Definition] foundDup := resolver.specIndex.GetMappedReferences()[r.Definition]
var circRef *index.CircularReferenceResult var circRef *index.CircularReferenceResult
if !foundDup.Circular { if !foundDup.Circular {
loop := append(journey, foundDup) loop := append(journey, foundDup)
circRef = &index.CircularReferenceResult{ circRef = &index.CircularReferenceResult{
Journey: loop, Journey: loop,
Start: foundDup, Start: foundDup,
LoopIndex: i, LoopIndex: i,
LoopPoint: foundDup, LoopPoint: foundDup,
IsInfiniteLoop: resolver.isInfiniteCircularDependency(foundDup, nil),
} }
resolver.circularReferences = append(resolver.circularReferences, circRef)
foundDup.Seen = true foundDup.Seen = true
foundDup.Circular = true foundDup.Circular = true
resolver.circularReferences = append(resolver.circularReferences, circRef)
} }
skip = true skip = true
} }
} }
if !skip { if !skip {
original := resolver.specIndex.GetMappedReferences()[r.Definition] original := resolver.specIndex.GetMappedReferences()[r.Definition]
resolved := resolver.VisitReference(original, seen, journey, resolve) resolved := resolver.VisitReference(original, seen, journey, resolve)
@@ -217,6 +224,30 @@ func (resolver *Resolver) VisitReference(ref *index.Reference, seen map[string]b
return ref.Node.Content return ref.Node.Content
} }
func (resolver *Resolver) isInfiniteCircularDependency(ref *index.Reference, initialRef *index.Reference) bool {
if ref == nil {
return false
}
for refDefinition := range ref.RequiredRefProperties {
r := resolver.specIndex.GetMappedReferences()[refDefinition]
if initialRef != nil && initialRef.Definition == r.Definition {
return true
}
ir := initialRef
if ir == nil {
ir = ref
}
if resolver.isInfiniteCircularDependency(r, ir) {
return true
}
}
return false
}
func (resolver *Resolver) extractRelatives(node *yaml.Node, func (resolver *Resolver) extractRelatives(node *yaml.Node,
foundRelatives map[string]bool, foundRelatives map[string]bool,
journey []*index.Reference, resolve bool) []*index.Reference { journey []*index.Reference, resolve bool) []*index.Reference {

View File

@@ -3,12 +3,13 @@ package utils
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/vmware-labs/yaml-jsonpath/pkg/yamlpath"
"gopkg.in/yaml.v3"
"net/url" "net/url"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"github.com/vmware-labs/yaml-jsonpath/pkg/yamlpath"
"gopkg.in/yaml.v3"
) )
type Case int8 type Case int8
@@ -402,15 +403,50 @@ func IsNodeRefValue(node *yaml.Node) (bool, *yaml.Node, string) {
return false, nil, "" return false, nil, ""
} }
// IsPropertyNodeRequired will check if a node is required within circular references
func IsPropertyNodeRequired(node *yaml.Node, propertyName string) bool {
_, requiredSeqNode := FindKeyNode("required", node.Content)
if requiredSeqNode == nil {
return false
}
_, propertiesMapNode := FindKeyNode("properties", node.Content)
if propertiesMapNode == nil {
return false
}
for _, requiredPropertyNode := range requiredSeqNode.Content {
_, requiredPropDefNode := FindKeyNode(requiredPropertyNode.Value, propertiesMapNode.Content)
if requiredPropDefNode == nil {
continue
}
isRef, _, defPath := IsNodeRefValue(requiredPropDefNode)
if isRef && defPath == propertyName {
return true
}
_, defItems := FindKeyNode("items", requiredPropDefNode.Content)
if defItems == nil {
continue
}
isRef, _, defPath = IsNodeRefValue(defItems)
if isRef && defPath == propertyName {
return true
}
}
return false
}
// FixContext will clean up a JSONpath string to be correctly traversable. // FixContext will clean up a JSONpath string to be correctly traversable.
func FixContext(context string) string { func FixContext(context string) string {
tokens := strings.Split(context, ".") tokens := strings.Split(context, ".")
var cleaned = []string{} var cleaned = []string{}
for i, t := range tokens { for i, t := range tokens {
if v, err := strconv.Atoi(t); err == nil { if v, err := strconv.Atoi(t); err == nil {
if v < 200 { // codes start here if v < 200 { // codes start here
if cleaned[i-1] != "" { if cleaned[i-1] != "" {
cleaned[i-1] += fmt.Sprintf("[%v]", t) cleaned[i-1] += fmt.Sprintf("[%v]", t)
@@ -421,8 +457,8 @@ func FixContext(context string) string {
continue continue
} }
cleaned = append(cleaned, strings.ReplaceAll(t, "(root)", "$")) cleaned = append(cleaned, strings.ReplaceAll(t, "(root)", "$"))
} }
return strings.Join(cleaned, ".") return strings.Join(cleaned, ".")
} }