101 lines
3.5 KiB
Python
101 lines
3.5 KiB
Python
import copy
import os
import re
##
import requests
from lxml import etree
|
|
|
|
|
|
class Config(object):
    """Load an XML config file, apply XSD attribute defaults and validate it.

    Two parses of the same file are kept side by side:

    * ``self.xml`` / ``self.tree``: element tags have their namespace removed
      (see :meth:`strip_ns`) so lookups can ignore namespaces.
    * ``self.ns_xml`` / ``self.ns_tree``: the namespaced document; this is the
      copy that :meth:`validate` checks against the schema.
    """

    # Schema used when the document names none via xsi:schemaLocation or
    # xsi:noNamespaceSchemaLocation.
    default_xsd = 'http://schema.xml.r00t2.io/projects/he_ipv6.xsd'
    # Seconds to wait on the network for a remote XSD; without a timeout
    # requests.get() can block indefinitely.
    xsd_timeout = 30

    def __init__(self, xml_path, *args, **kwargs):
        """Parse, default-populate and validate the XML file at *xml_path*.

        ``~`` is expanded and the path is made absolute. Extra positional and
        keyword arguments are accepted but ignored.

        Raises:
            ValueError: *xml_path* is not an existing regular file.
        """
        self.xml_path = os.path.abspath(os.path.expanduser(xml_path))
        if not os.path.isfile(self.xml_path):
            raise ValueError('xml_path does not exist')
        self.tree = None             # ElementTree wrapping self.xml
        self.ns_tree = None          # ElementTree wrapping self.ns_xml
        self.xml = None              # root element, namespaces stripped from tags
        self.ns_xml = None           # root element, namespaces intact
        self.raw = None              # raw file contents (bytes)
        self.xsd = None              # compiled etree.XMLSchema
        self.defaults_parser = None  # schema-aware parser built by populate_defaults()
        # Initialised empty here and not filled by any method of this class;
        # presumably populated by callers or subclasses (TODO confirm).
        self.obj = None
        self.tunnels = {}
        self.creds = {}
        self.parse_raw()
        self.get_xsd()
        self.populate_defaults()
        self.validate()

    @staticmethod
    def _select_schema_location(value, root_ns = None):
        """Return the schema URL to load from an xsi schema-location value.

        A proper ``xsi:schemaLocation`` is whitespace-separated
        (namespace, location) pairs: the location paired with *root_ns* is
        returned, falling back to the first pair's location. Any other shape
        (a "lazy" schemaLocation holding only a URL, or a
        noNamespaceSchemaLocation) yields its first token.
        """
        tokens = value.split()
        if len(tokens) >= 2 and len(tokens) % 2 == 0:
            locations = dict(zip(tokens[0::2], tokens[1::2]))
            return(locations.get(root_ns, tokens[1]))
        return(tokens[0])

    def get_xsd(self):
        """Locate, load and compile the XML Schema into ``self.xsd``.

        The schema URL comes from the root's ``xsi:schemaLocation``, then
        ``xsi:noNamespaceSchemaLocation``, then :attr:`default_xsd`.
        ``file://`` URLs are read from disk; anything else is downloaded with
        ``requests.get``. The schema is parsed with a base URL of its own
        directory so relative references inside it resolve from there.

        Raises:
            RuntimeError: the HTTP response status was not OK.
            OSError: a ``file://`` schema could not be read.
        """
        root = self.ns_xml
        xsi = root.nsmap.get('xsi', 'http://www.w3.org/2001/XMLSchema-instance')
        candidates = (
            root.attrib.get('{{{0}}}schemaLocation'.format(xsi)),
            root.attrib.get('{{{0}}}noNamespaceSchemaLocation'.format(xsi)),
        )
        # Skip missing or blank attributes (a blank value used to raise IndexError).
        location = next((c for c in candidates if c and c.strip()), self.default_xsd)
        schema_url = self._select_schema_location(location, etree.QName(root).namespace)
        if schema_url.startswith('file://'):
            # Make the path absolute so a relative path still yields a usable
            # base_url (dirname('x.xsd') is '', which used to become '/').
            schema_path = os.path.abspath(schema_url[len('file://'):])
            with open(schema_path, 'rb') as fh:
                raw_xsd = fh.read()
            base_url = os.path.dirname(schema_path).rstrip('/') + '/'
        else:
            req = requests.get(schema_url, timeout = self.xsd_timeout)
            if not req.ok:
                raise RuntimeError('Could not download XSD')
            raw_xsd = req.content
            # Directory of the final (post-redirect) URL. URLs always use '/',
            # so split the string itself rather than going through os.path.
            base_url = req.url.rsplit('/', 1)[0] + '/'
        self.xsd = etree.XMLSchema(etree.XML(raw_xsd, base_url = base_url))
        return(None)

    def parse_raw(self, parser = None):
        """(Re)build both element trees from the file contents.

        The file is read once and cached in ``self.raw``. Both copies are
        parsed with *parser* (lxml's default when ``None``), XIncludes are
        expanded, and namespaces are then stripped from ``self.xml`` only.
        """
        if not self.raw:
            with open(self.xml_path, 'rb') as fh:
                self.raw = fh.read()
        # base_url lets relative XInclude hrefs resolve against the config
        # file's location rather than the process's working directory.
        self.xml = etree.fromstring(self.raw, parser = parser, base_url = self.xml_path)
        self.ns_xml = etree.fromstring(self.raw, parser = parser, base_url = self.xml_path)
        self.tree = self.xml.getroottree()
        self.ns_tree = self.ns_xml.getroottree()
        self.tree.xinclude()
        self.ns_tree.xinclude()
        self.strip_ns()
        return(None)

    def populate_defaults(self):
        """Re-parse the file so attribute defaults declared in the XSD are filled in.

        Loads the schema first if needed. The schema-aware parser
        (``attribute_defaults=True``) is created once and reused; since it is
        given the schema, lxml also validates the document while parsing.
        """
        if self.xsd is None:
            self.get_xsd()
        if self.defaults_parser is None:
            self.defaults_parser = etree.XMLParser(schema = self.xsd, attribute_defaults = True)
        self.parse_raw(parser = self.defaults_parser)
        return(None)

    def remove_defaults(self):
        """Re-parse with lxml's default parser, dropping schema-supplied attribute defaults."""
        self.parse_raw()
        return(None)

    def strip_ns(self, obj = None):
        """Remove namespaces from element tags.

        With *obj* ``None``, ``self.xml`` (and therefore ``self.tree``, which
        wraps the same document) is modified in place and ``None`` is
        returned. With an lxml element or element tree, a stripped deep copy
        is returned and *obj* itself is left untouched.

        Raises:
            ValueError: *obj* is neither ``None`` nor an lxml element/tree.
        """
        # https://stackoverflow.com/questions/30232031/how-can-i-strip-namespaces-out-of-an-lxml-tree/30233635#30233635
        xpathq = "descendant-or-self::*[namespace-uri()!='']"
        # Compare with None explicitly: an lxml element with no children is
        # falsy, so `not obj` wrongly took this branch for leaf elements.
        if obj is None:
            for e in self.xml.xpath(xpathq):
                e.tag = etree.QName(e).localname
            return(None)
        if isinstance(obj, (etree._Element, etree._ElementTree)):
            obj = copy.deepcopy(obj)
            for e in obj.xpath(xpathq):
                e.tag = etree.QName(e).localname
            return(obj)
        raise ValueError('Did not know how to parse obj parameter')

    def validate(self):
        """Assert that the namespaced document conforms to the schema.

        Loads the schema first if needed; lxml's ``assertValid`` raises
        ``etree.DocumentInvalid`` on failure.
        """
        if self.xsd is None:
            self.get_xsd()
        self.xsd.assertValid(self.ns_tree)
        return(None)
|