summaryrefslogtreecommitdiff
path: root/Pivot.py
diff options
context:
space:
mode:
Diffstat (limited to 'Pivot.py')
-rw-r--r--Pivot.py397
1 files changed, 397 insertions, 0 deletions
diff --git a/Pivot.py b/Pivot.py
new file mode 100644
index 0000000..28cea92
--- /dev/null
+++ b/Pivot.py
@@ -0,0 +1,397 @@
+#!/usr/bin/env python
+
+import ldap, ldap.dn, ldap.filter
+import Backend, sets
+
+HOST = "ldap://localhost:3890"
+BINDDN = "cn=root,dc=fam"
+PASSWORD = "barn"
+BASE = "dc=fam"
+REF_ATTRIBUTE = "member"
+KEY_ATTRIBUTE = "uid"
+TAG_ATTRIBUTE = "memberOf"
+
+OBJECT_CLASS = "groupOfNames"
+DN_ATTRIBUTE = "cn"
+
+# hasSubordinates: TRUE
+
+SCOPE_BASE = "0"
+SCOPE_ONE = "1"
+SCOPE_SUB = "2"
+
+class Storage:
+ def __init__(self, url):
+ self.url = url
+ self.ldap = None
+
+ def __connect(self, force = False):
+ if self.ldap:
+ if not force:
+ return
+ self.ldap.unbind()
+ self.ldap = ldap.initialize(self.url)
+ try:
+ self.ldap.simple_bind_s(BINDDN, PASSWORD)
+ except ldap.LDAPError, ex:
+ raise Backend.Error(Backend.OPERATIONS_ERROR,
+ "Couldn't do internal authenticate: %s" % ex.args[0]["desc"])
+
+ def read(self, dn, filtstr = "(objectClass=*)", attrs = None, retries = 1):
+ try:
+ self.__connect()
+ results = self.ldap.search_s(dn, ldap.SCOPE_BASE, filtstr, attrs, 0)
+
+ if not results:
+ return None
+ (dn, entry) = results[0]
+ return entry
+
+ except ldap.SERVER_DOWN, ex:
+ if retries <= 0:
+ raise sys.exc_type, sys.exc_value, sys.exc_traceback
+ self.__connect(True)
+ return self.read(dn, filtstr, attrs, retries - 1)
+
+ def search(self, base, filtstr, attrs = [], retries = 1):
+ try:
+ self.__connect()
+ return self.ldap.search_s(base, ldap.SCOPE_SUBTREE, filtstr, attrs, 0)
+
+ except ldap.SERVER_DOWN, ex:
+ if retries <= 0:
+ raise sys.exc_type, sys.exc_value, sys.exc_traceback
+ self.__connect(True)
+ return self.search(base, filtstr, attrs, retries - 1)
+
+ def modify(self, dn, mods, retries = 1):
+ try:
+ self.__connect()
+ self.ldap.modify_s(dn, mods)
+
+ except ldap.SERVER_DOWN, ex:
+ if retries <= 0:
+ raise sys.exc_type, sys.exc_value, sys.exc_traceback
+ self.__connect(True)
+ self.modify(dn, mods, retries - 1)
+
+
+class Static:
+ def __init__(self, func):
+ self.__call__ = func
+
+class Tags:
+ def __init__(self, database, tags):
+ self.database = database
+ self.tags = tags
+
+ def __refresh(self, force = False):
+ if not force and self.tags is not None:
+ return
+ try:
+ results = self.database.storage.search(BASE, "(%s=*)" % TAG_ATTRIBUTE, [TAG_ATTRIBUTE])
+
+ tags = { }
+ for (dn, entry) in results:
+ for attr in entry.values():
+ for value in attr:
+ tags[value] = DN_ATTRIBUTE
+ self.tags = tags
+
+ except ldap.LDAPError, ex:
+ raise Backend.Error(Backend.OPERATIONS_ERROR,
+ "Couldn't search ldap for keys: %s" % ex.args[0]["desc"])
+
+ def __len__(self):
+ self.__refresh()
+ return len(self.tags)
+ def __getitem__(self, k):
+ self.__refresh()
+ return self.tags[k]
+ def __setitem__(self, k):
+ assert False
+ def __delitem__(self, k):
+ assert False
+ def __contains__(self, k):
+ self.__refresh()
+ return k in self.tags
+ def __iter__(self):
+ self.__refresh()
+ return iter(self.tags)
+ def items(self):
+ self.__refresh()
+ return self.tags.items()
+
+
+ def from_dn(dn):
+ try:
+ parsed = ldap.dn.str2dn(dn)
+ except ValueError:
+ raise Backend.Error(Backend.Error.PROTOCOL_ERROR, "invalid dn: %s" % dn)
+ return Tags.from_parsed_dn(parsed)
+ from_dn = Static(from_dn)
+
+ def from_parsed_dn(parsed):
+ tags = { }
+ for (typ, val, num) in parsed[0]:
+ tags[val] = typ
+ return Tags(None, tags)
+ from_parsed_dn = Static(from_parsed_dn)
+
+ def from_database(database):
+ return Tags(database, None)
+ from_database = Static(from_database)
+
+
+def is_parsed_dn_parent (dn, parent):
+ if len(dn) <= len(parent):
+ return False
+ # Go backwards and validate each parent
+ for i in range(-1, -1 - len(parent)):
+ print i, dn[i], parent[i]
+ if dn[i] != parent[i]:
+ return False
+ return True
+
+class Database(Backend.Database):
+ def __init__(self, suffix):
+ Backend.Database.__init__(self, suffix)
+ self.storage = Storage(HOST)
+ try:
+ self.suffix_dn = ldap.dn.str2dn(self.suffix)
+ except ValueError:
+ raise Backend.Error(Backend.Error.PROTOCOL_ERROR, "invalid suffix dn")
+
+ def __search_tag_keys(self, tags):
+
+ if not len(tags):
+ return []
+
+ # Build up a filter
+ filter = [ldap.filter.filter_format("(%s=%s)", (TAG_ATTRIBUTE, tag)) for tag in tags]
+ if len(filter) > 1:
+ filter = "(&" + "".join(filter) + ")"
+ else:
+ filter = filter[0]
+
+ try:
+ # Search for all those guys
+ results = self.storage.search(BASE, filter, [ KEY_ATTRIBUTE ])
+
+ except ldap.LDAPError, ex:
+ raise Backend.Error(Backend.OPERATIONS_ERROR,
+ "Couldn't search ldap for tags: %s" % ex.args[0]["desc"])
+
+ return [entry[KEY_ATTRIBUTE][0] for (dn, entry) in results if entry[KEY_ATTRIBUTE]]
+
+
+ def __search_key_dns(self, key):
+
+ # Build up a filter
+ filter = ldap.filter.filter_format("(%s=%s)", (KEY_ATTRIBUTE, key))
+
+ try:
+ # Do the actual search
+ results = self.storage.search(BASE, filter)
+
+ except ldap.LDAPError, ex:
+ raise Backend.Error(Backend.OPERATIONS_ERROR,
+ "Couldn't search ldap for keys: %s" % ex.args[0]["desc"])
+
+ return [dn for (dn, entry) in results if dn]
+
+
+ def __build_root_entry(self, tags, any_attrs = True):
+ attrs = {
+ "objectClass" : [ "top" ],
+ "hasSubordinates" : [ ]
+ }
+
+ for (typ, val, num) in self.suffix_dn[0]:
+ if not attrs.has_key(typ):
+ attrs[typ] = [ ]
+ attrs[typ].append(val)
+
+ # Note that we don't access 'tags' unless attributes requested
+ if any_attrs:
+ attrs["hasSubordinates"].append(tags and "FALSE" or "TRUE")
+
+ return (self.suffix, attrs)
+
+
+ def __build_pivot_entry(self, tags, any_attrs = True):
+ attrs = {
+ REF_ATTRIBUTE : [ ],
+ "objectClass" : [ OBJECT_CLASS ],
+ "hasSubordinates" : [ "FALSE" ]
+ }
+
+ # Build up a DN, and relevant attrs
+ rdn = []
+ for tag, typ in tags.items():
+ rdn.append((typ, tag, 1))
+ if not attrs.has_key(typ):
+ attrs[typ] = [ ]
+ attrs[typ].append(tag)
+ dn = [ rdn ]
+ dn.extend(self.suffix_dn)
+ dn = ldap.dn.dn2str(dn)
+
+ # Get out all the attributes
+ if any_attrs:
+ for key in self.__search_tag_keys(tags):
+ attrs[REF_ATTRIBUTE].append(key)
+
+ return (dn, attrs)
+
+
+
+ def __limit_results(self, args, results):
+ # TODO: Support sizelimit
+ # TODO: Support a filter
+
+ # Only return the attribute names?
+ if args["attrsonly"] == "1":
+ for (dn, attrs) in results:
+ for attr in attrs:
+ attrs[attr] = [ "" ]
+
+ # Only return these attributes?
+ which = args["attrs"]
+ if which != "all" and which != "*" and which != "+":
+ which = which.split(" ")
+ for (dn, attrs) in results:
+ for attr in attrs.keys():
+ if attr not in which:
+ del attrs[attr]
+
+
+ def search(self, dn, args):
+ results = []
+
+ try:
+ parsed = ldap.dn.str2dn(dn)
+ except:
+ raise Backend.Error(Backend.PROTOCOL_ERROR, "Invalid dn in search: %s" % dn)
+
+ # Arguments sent
+ scope = args["scope"] or SCOPE_BASE
+ any_attrs = len(args["attrs"].strip()) > 0
+
+ # Start at the root
+ if parsed == self.suffix_dn:
+ tags = Tags.from_database(self)
+ if scope == SCOPE_BASE or scope == SCOPE_SUB:
+ results.append(self.__build_root_entry(tags, any_attrs))
+ if scope == SCOPE_ONE or scope == SCOPE_SUB:
+ # Process each tag individually, by default
+ for (tag, typ) in tags.items():
+ results.append(self.__build_pivot_entry({ tag : typ }, any_attrs))
+
+ # Start at a tag
+ elif is_parsed_dn_parent (parsed, self.suffix_dn):
+ tags = Tags.from_parsed_dn(parsed)
+ if scope == SCOPE_BASE or scope == SCOPE_SUB:
+ results.append(self.__build_pivot_entry(tags, any_attrs))
+
+ # We don't have that base
+ else:
+ raise Backend.Error(Backend.NO_SUCH_OBJECT, "DN '%s' does not exist" % dn)
+
+ self.__limit_results(args, results)
+ return results
+
+
+ def __build_key_mods(self, key, tags, op, mods):
+ dns = self.__search_key_dns(key)
+ if not dns:
+ raise Backend.Error(Backend.CONSTRAINT_VIOLATION,
+ "Cannot find an entry for %s '%s'" % (KEY_ATTRIBUTE, key))
+ for dn in dns:
+ if not mods.has_key(dn):
+ mods[dn] = (key, [])
+ for tag in tags:
+ mods[dn][1].append((op, TAG_ATTRIBUTE, tag))
+
+ def modify(self, dn, mods):
+
+ try:
+ parsed = ldap.dn.str2dn(dn)
+ except:
+ raise Backend.Error(Backend.PROTOCOL_ERROR, "Invalid dn in modify: %s" % dn)
+
+ if dn == self.suffix:
+ raise Backend.Error(Backend.INSUFFICIENT_ACCESS,
+ "Cannot modify root dn of pivot area: %s" % dn)
+
+ if not is_parsed_dn_parent (parsed, self.suffix_dn):
+ raise Backend.Error(Backend.NO_SUCH_OBJECT,
+ "DN '%s' does not exist" % dn)
+
+ tags = Tags.from_parsed_dn(parsed)
+
+ add_keys = sets.Set()
+ remove_keys = sets.Set()
+ remove_all = False
+
+ # Parse out all the adds and removes
+ for (op, attr, value) in mods:
+
+ if attr != REF_ATTRIBUTE:
+ raise Backend.Error(Backend.CONSTRAINT_VIOLATION,
+ "Cannot modify '%s' attribute" % attr)
+
+ if op == ldap.MOD_ADD:
+ if value:
+ add_keys.add(value)
+ elif op == ldap.MOD_REPLACE:
+ remove_all = True
+ if value:
+ add_keys.add(value)
+ elif op == ldap.MOD_DELETE:
+ if value:
+ remove_keys.add(value)
+ else:
+ remove_all = True
+ else:
+ continue
+
+ # Remove all of the ref attribute
+ if remove_all:
+ for key in self.__search_tag_keys (tags):
+ remove_keys.add(key)
+
+ # Make them all unique, and non conflicting
+ for key in add_keys.copy():
+ if key in remove_keys:
+ remove_keys.remove(key)
+ add_keys.remove(key)
+
+ # Change them to DNs, and build mods objects for each dn
+ keys_and_mods_by_dn = { }
+ for key in add_keys:
+ self.__build_key_mods(key, tags, ldap.MOD_ADD, keys_and_mods_by_dn)
+ for key in remove_keys:
+ self.__build_key_mods(key, tags, ldap.MOD_DELETE, keys_and_mods_by_dn)
+
+
+ # Now perform the actual actions, combining errors
+ errors = []
+ for (dn, (key, mod)) in keys_and_mods_by_dn.items():
+ try:
+ print dn, mod
+ self.storage.modify(dn, mod)
+ except (ldap.TYPE_OR_VALUE_EXISTS, ldap.NO_SUCH_ATTRIBUTE):
+ continue
+ except ldap.NO_SUCH_OBJECT:
+ errors.append(key)
+ except ldap.LDAPError, ex:
+ raise Backend.Error(Backend.OPERATIONS_ERROR,
+ "Couldn't perform one of the modifications: %s" % ex.args[0]["desc"])
+
+ # Send back errors
+ if errors:
+ raise Backend.Error(Backend.CONSTRAINT_VIOLATION,
+ "Couldn't change %s for %s" % (KEY_ATTRIBUTE, ", ".join(errors)))
+
+