feat: Allow circular references when the properties are not marked as required

This commit is contained in:
Benjamin Nolan (TwoWholeWorms)
2023-01-05 20:43:45 +01:00
committed by Dave Shanley
parent e8a954d5ae
commit ee504c543a
4 changed files with 196 additions and 43 deletions

View File

@@ -9,15 +9,21 @@ type CircularReferenceResult struct {
LoopIndex int
LoopPoint *Reference
IsPolymorphicResult bool // if this result comes from a polymorphic loop.
IsInfiniteLoop bool // if all the definitions in the reference loop are marked as required, this is an infinite circular reference, thus is not allowed.
}
func (c *CircularReferenceResult) GenerateJourneyPath() string {
buf := strings.Builder{}
for i, ref := range c.Journey {
buf.WriteString(ref.Name)
if i+1 < len(c.Journey) {
if i > 0 {
buf.WriteString(" -> ")
}
buf.WriteString(ref.Name)
// buf.WriteString(" (")
// buf.WriteString(ref.Definition)
// buf.WriteString(")")
}
return buf.String()
}

View File

@@ -35,16 +35,17 @@ const (
// Reference is a wrapper around *yaml.Node results to make things more manageable when performing
// algorithms on data models. the *yaml.Node def is just a bit too low level for tracking state.
type Reference struct {
Definition string
Name string
Node *yaml.Node
ParentNode *yaml.Node
Resolved bool
Circular bool
Seen bool
IsRemote bool
RemoteLocation string
Path string // this won't always be available.
Definition string
Name string
Node *yaml.Node
ParentNode *yaml.Node
Resolved bool
Circular bool
Seen bool
IsRemote bool
RemoteLocation string
Path string // this won't always be available.
RequiredRefProperties map[string][]string // definition names (eg, #/definitions/One) to a list of required properties on this definition which reference that definition
}
// ReferenceMapped is a helper struct for mapped references put into sequence (we lose the key)
@@ -1614,10 +1615,12 @@ func (index *SpecIndex) FindComponent(componentId string, parent *yaml.Node) *Re
return nil
}
// FIXME: This is a potential security hole, and needs to be made optional (see log4j fiasco)
remoteLookup := func(id string) (*yaml.Node, *yaml.Node, error) {
return index.lookupRemoteReference(id)
}
// FIXME: As above
fileLookup := func(id string) (*yaml.Node, *yaml.Node, error) {
return index.lookupFileReference(id)
}
@@ -1698,18 +1701,93 @@ func (index *SpecIndex) extractDefinitionsAndSchemas(schemasNode *yaml.Node, pat
name = schema.Value
continue
}
def := fmt.Sprintf("%s%s", pathPrefix, name)
ref := &Reference{
Definition: def,
Name: name,
Node: schema,
Path: fmt.Sprintf("$.components.schemas.%s", name),
ParentNode: schemasNode,
Definition: def,
Name: name,
Node: schema,
Path: fmt.Sprintf("$.components.schemas.%s", name),
ParentNode: schemasNode,
RequiredRefProperties: index.extractDefinitionRequiredRefProperties(schemasNode, map[string][]string{}),
}
index.allSchemas[def] = ref
}
}
// extractDefinitionRequiredRefProperties goes through the direct properties of a schema and extracts the map of required definitions from within it
func (index *SpecIndex) extractDefinitionRequiredRefProperties(schemaNode *yaml.Node, reqRefProps map[string][]string) map[string][]string {
if schemaNode == nil {
return reqRefProps
}
_, requiredSeqNode := utils.FindKeyNode("required", schemaNode.Content)
if requiredSeqNode == nil {
return reqRefProps
}
_, propertiesMapNode := utils.FindKeyNode("properties", schemaNode.Content)
if propertiesMapNode == nil {
// TODO: Log a warning on the resolver, because if you have required properties, but no actual properties, something is wrong
return reqRefProps
}
name := ""
for i, param := range propertiesMapNode.Content {
if i%2 == 0 {
name = param.Value
continue
}
_, paramPropertiesMapNode := utils.FindKeyNode("properties", param.Content)
if paramPropertiesMapNode != nil {
reqRefProps = index.extractDefinitionRequiredRefProperties(param, reqRefProps)
}
for _, key := range []string{"allOf", "oneOf", "anyOf"} {
_, ofNode := utils.FindKeyNode(key, param.Content)
if ofNode != nil {
for _, ofNodeItem := range ofNode.Content {
reqRefProps = index.extractRequiredReferenceProperties(ofNodeItem, name, reqRefProps)
}
}
}
}
for _, requiredPropertyNode := range requiredSeqNode.Content {
_, requiredPropDefNode := utils.FindKeyNode(requiredPropertyNode.Value, propertiesMapNode.Content)
if requiredPropDefNode == nil {
continue
}
reqRefProps = index.extractRequiredReferenceProperties(requiredPropDefNode, requiredPropertyNode.Value, reqRefProps)
}
return reqRefProps
}
// extractRequiredReferenceProperties returns a map of definition names to the property or properties which reference it within a node
func (index *SpecIndex) extractRequiredReferenceProperties(requiredPropDefNode *yaml.Node, propName string, reqRefProps map[string][]string) map[string][]string {
isRef, _, defPath := utils.IsNodeRefValue(requiredPropDefNode)
if !isRef {
_, defItems := utils.FindKeyNode("items", requiredPropDefNode.Content)
if defItems != nil {
isRef, _, defPath = utils.IsNodeRefValue(defItems)
}
}
if /* still */ !isRef {
return reqRefProps
}
if _, ok := reqRefProps[defPath]; !ok {
reqRefProps[defPath] = []string{}
}
reqRefProps[defPath] = append(reqRefProps[defPath], propName)
return reqRefProps
}
func (index *SpecIndex) extractComponentParameters(paramsNode *yaml.Node, pathPrefix string) {
var name string
for i, param := range paramsNode.Content {
@@ -1899,11 +1977,13 @@ func (index *SpecIndex) FindComponentInRoot(componentId string) *Reference {
res, _ := path.Find(index.root)
if len(res) == 1 {
ref := &Reference{
Definition: componentId,
Name: name,
Node: res[0],
Path: friendlySearch,
Definition: componentId,
Name: name,
Node: res[0],
Path: friendlySearch,
RequiredRefProperties: index.extractDefinitionRequiredRefProperties(res[0], map[string][]string{}),
}
return ref
}
}

View File

@@ -5,6 +5,7 @@ package resolver
import (
"fmt"
"github.com/pb33f/libopenapi/index"
"github.com/pb33f/libopenapi/utils"
"gopkg.in/yaml.v3"
@@ -87,7 +88,6 @@ func (resolver *Resolver) GetNonPolymorphicCircularErrors() []*index.CircularRef
// re-organize the node tree. Make sure you have copied your original tree before running this (if you want to preserve
// original data)
func (resolver *Resolver) Resolve() []*ResolvingError {
mapped := resolver.specIndex.GetMappedReferencesSequenced()
mappedIndex := resolver.specIndex.GetMappedReferences()
@@ -98,7 +98,6 @@ func (resolver *Resolver) Resolve() []*ResolvingError {
}
schemas := resolver.specIndex.GetAllSchemas()
for s, schemaRef := range schemas {
if mappedIndex[s] == nil {
seenReferences := make(map[string]bool)
@@ -118,8 +117,13 @@ func (resolver *Resolver) Resolve() []*ResolvingError {
}
for _, circRef := range resolver.circularReferences {
// If the circular reference is not required, we can ignore it, as it's a terminable loop rather than an infinite one
if !circRef.IsInfiniteLoop {
continue
}
resolver.resolvingErrors = append(resolver.resolvingErrors, &ResolvingError{
ErrorRef: fmt.Errorf("Circular reference detected: %s", circRef.Start.Name),
ErrorRef: fmt.Errorf("Infinite circular reference detected: %s", circRef.Start.Name),
Node: circRef.LoopPoint.Node,
Path: circRef.GenerateJourneyPath(),
})
@@ -130,7 +134,6 @@ func (resolver *Resolver) Resolve() []*ResolvingError {
// CheckForCircularReferences Check for circular references, without resolving, a non-destructive run.
func (resolver *Resolver) CheckForCircularReferences() []*ResolvingError {
mapped := resolver.specIndex.GetMappedReferencesSequenced()
mappedIndex := resolver.specIndex.GetMappedReferences()
for _, ref := range mapped {
@@ -138,6 +141,7 @@ func (resolver *Resolver) CheckForCircularReferences() []*ResolvingError {
var journey []*index.Reference
resolver.VisitReference(ref.Reference, seenReferences, journey, false)
}
schemas := resolver.specIndex.GetAllSchemas()
for s, schemaRef := range schemas {
if mappedIndex[s] == nil {
@@ -146,9 +150,15 @@ func (resolver *Resolver) CheckForCircularReferences() []*ResolvingError {
resolver.VisitReference(schemaRef, seenReferences, journey, false)
}
}
for _, circRef := range resolver.circularReferences {
// If the circular reference is not required, we can ignore it, as it's a terminable loop rather than an infinite one
if !circRef.IsInfiniteLoop {
continue
}
resolver.resolvingErrors = append(resolver.resolvingErrors, &ResolvingError{
ErrorRef: fmt.Errorf("Circular reference detected: %s", circRef.Start.Name),
ErrorRef: fmt.Errorf("Infinite circular reference detected: %s", circRef.Start.Name),
Node: circRef.LoopPoint.Node,
Path: circRef.GenerateJourneyPath(),
CircularReference: circRef,
@@ -161,7 +171,6 @@ func (resolver *Resolver) CheckForCircularReferences() []*ResolvingError {
// VisitReference will visit a reference as part of a journey and will return resolved nodes.
func (resolver *Resolver) VisitReference(ref *index.Reference, seen map[string]bool, journey []*index.Reference, resolve bool) []*yaml.Node {
if ref.Resolved || ref.Seen {
return ref.Node.Content
}
@@ -173,34 +182,32 @@ func (resolver *Resolver) VisitReference(ref *index.Reference, seen map[string]b
seen[ref.Definition] = true
for _, r := range relatives {
// check if we have seen this on the journey before, if so! it's circular
skip := false
for i, j := range journey {
if j.Definition == r.Definition {
foundDup := resolver.specIndex.GetMappedReferences()[r.Definition]
var circRef *index.CircularReferenceResult
if !foundDup.Circular {
loop := append(journey, foundDup)
circRef = &index.CircularReferenceResult{
Journey: loop,
Start: foundDup,
LoopIndex: i,
LoopPoint: foundDup,
Journey: loop,
Start: foundDup,
LoopIndex: i,
LoopPoint: foundDup,
IsInfiniteLoop: resolver.isInfiniteCircularDependency(foundDup, nil),
}
resolver.circularReferences = append(resolver.circularReferences, circRef)
foundDup.Seen = true
foundDup.Circular = true
resolver.circularReferences = append(resolver.circularReferences, circRef)
}
skip = true
}
}
if !skip {
original := resolver.specIndex.GetMappedReferences()[r.Definition]
resolved := resolver.VisitReference(original, seen, journey, resolve)
@@ -217,6 +224,30 @@ func (resolver *Resolver) VisitReference(ref *index.Reference, seen map[string]b
return ref.Node.Content
}
func (resolver *Resolver) isInfiniteCircularDependency(ref *index.Reference, initialRef *index.Reference) bool {
if ref == nil {
return false
}
for refDefinition := range ref.RequiredRefProperties {
r := resolver.specIndex.GetMappedReferences()[refDefinition]
if initialRef != nil && initialRef.Definition == r.Definition {
return true
}
ir := initialRef
if ir == nil {
ir = ref
}
if resolver.isInfiniteCircularDependency(r, ir) {
return true
}
}
return false
}
func (resolver *Resolver) extractRelatives(node *yaml.Node,
foundRelatives map[string]bool,
journey []*index.Reference, resolve bool) []*index.Reference {

View File

@@ -3,12 +3,13 @@ package utils
import (
"encoding/json"
"fmt"
"github.com/vmware-labs/yaml-jsonpath/pkg/yamlpath"
"gopkg.in/yaml.v3"
"net/url"
"regexp"
"strconv"
"strings"
"github.com/vmware-labs/yaml-jsonpath/pkg/yamlpath"
"gopkg.in/yaml.v3"
)
type Case int8
@@ -402,15 +403,50 @@ func IsNodeRefValue(node *yaml.Node) (bool, *yaml.Node, string) {
return false, nil, ""
}
// IsPropertyNodeRequired will check if a node is required within circular references
func IsPropertyNodeRequired(node *yaml.Node, propertyName string) bool {
_, requiredSeqNode := FindKeyNode("required", node.Content)
if requiredSeqNode == nil {
return false
}
_, propertiesMapNode := FindKeyNode("properties", node.Content)
if propertiesMapNode == nil {
return false
}
for _, requiredPropertyNode := range requiredSeqNode.Content {
_, requiredPropDefNode := FindKeyNode(requiredPropertyNode.Value, propertiesMapNode.Content)
if requiredPropDefNode == nil {
continue
}
isRef, _, defPath := IsNodeRefValue(requiredPropDefNode)
if isRef && defPath == propertyName {
return true
}
_, defItems := FindKeyNode("items", requiredPropDefNode.Content)
if defItems == nil {
continue
}
isRef, _, defPath = IsNodeRefValue(defItems)
if isRef && defPath == propertyName {
return true
}
}
return false
}
// FixContext will clean up a JSONpath string to be correctly traversable.
func FixContext(context string) string {
tokens := strings.Split(context, ".")
var cleaned = []string{}
for i, t := range tokens {
if v, err := strconv.Atoi(t); err == nil {
if v < 200 { // codes start here
if cleaned[i-1] != "" {
cleaned[i-1] += fmt.Sprintf("[%v]", t)
@@ -421,8 +457,8 @@ func FixContext(context string) string {
continue
}
cleaned = append(cleaned, strings.ReplaceAll(t, "(root)", "$"))
}
return strings.Join(cleaned, ".")
}