diff options
Diffstat (limited to 'Pivot.py')
-rw-r--r-- | Pivot.py | 397 |
1 files changed, 397 insertions, 0 deletions
diff --git a/Pivot.py b/Pivot.py new file mode 100644 index 0000000..28cea92 --- /dev/null +++ b/Pivot.py @@ -0,0 +1,397 @@ +#!/usr/bin/env python + +import ldap, ldap.dn, ldap.filter +import Backend, sets + +HOST = "ldap://localhost:3890" +BINDDN = "cn=root,dc=fam" +PASSWORD = "barn" +BASE = "dc=fam" +REF_ATTRIBUTE = "member" +KEY_ATTRIBUTE = "uid" +TAG_ATTRIBUTE = "memberOf" + +OBJECT_CLASS = "groupOfNames" +DN_ATTRIBUTE = "cn" + +# hasSubordinates: TRUE + +SCOPE_BASE = "0" +SCOPE_ONE = "1" +SCOPE_SUB = "2" + +class Storage: + def __init__(self, url): + self.url = url + self.ldap = None + + def __connect(self, force = False): + if self.ldap: + if not force: + return + self.ldap.unbind() + self.ldap = ldap.initialize(self.url) + try: + self.ldap.simple_bind_s(BINDDN, PASSWORD) + except ldap.LDAPError, ex: + raise Backend.Error(Backend.OPERATIONS_ERROR, + "Couldn't do internal authenticate: %s" % ex.args[0]["desc"]) + + def read(self, dn, filtstr = "(objectClass=*)", attrs = None, retries = 1): + try: + self.__connect() + results = self.ldap.search_s(dn, ldap.SCOPE_BASE, filtstr, attrs, 0) + + if not results: + return None + (dn, entry) = results[0] + return entry + + except ldap.SERVER_DOWN, ex: + if retries <= 0: + raise sys.exc_type, sys.exc_value, sys.exc_traceback + self.__connect(True) + return self.read(dn, filtstr, attrs, retries - 1) + + def search(self, base, filtstr, attrs = [], retries = 1): + try: + self.__connect() + return self.ldap.search_s(base, ldap.SCOPE_SUBTREE, filtstr, attrs, 0) + + except ldap.SERVER_DOWN, ex: + if retries <= 0: + raise sys.exc_type, sys.exc_value, sys.exc_traceback + self.__connect(True) + return self.search(base, filtstr, attrs, retries - 1) + + def modify(self, dn, mods, retries = 1): + try: + self.__connect() + self.ldap.modify_s(dn, mods) + + except ldap.SERVER_DOWN, ex: + if retries <= 0: + raise sys.exc_type, sys.exc_value, sys.exc_traceback + self.__connect(True) + self.modify(dn, mods, retries - 1) + + +class Static: + def __init__(self, func): + self.__call__ = func + +class Tags: + def __init__(self, database, tags): + self.database = database + self.tags = tags + + def __refresh(self, force = False): + if not force and self.tags is not None: + return + try: + results = self.database.storage.search(BASE, "(%s=*)" % TAG_ATTRIBUTE, [TAG_ATTRIBUTE]) + + tags = { } + for (dn, entry) in results: + for attr in entry.values(): + for value in attr: + tags[value] = DN_ATTRIBUTE + self.tags = tags + + except ldap.LDAPError, ex: + raise Backend.Error(Backend.OPERATIONS_ERROR, + "Couldn't search ldap for keys: %s" % ex.args[0]["desc"]) + + def __len__(self): + self.__refresh() + return len(self.tags) + def __getitem__(self, k): + self.__refresh() + return self.tags[k] + def __setitem__(self, k): + assert False + def __delitem__(self, k): + assert False + def __contains__(self, k): + self.__refresh() + return k in self.tags + def __iter__(self): + self.__refresh() + return iter(self.tags) + def items(self): + self.__refresh() + return self.tags.items() + + + def from_dn(dn): + try: + parsed = ldap.dn.str2dn(dn) + except ValueError: + raise Backend.Error(Backend.Error.PROTOCOL_ERROR, "invalid dn: %s" % dn) + return Tags.from_parsed_dn(parsed) + from_dn = Static(from_dn) + + def from_parsed_dn(parsed): + tags = { } + for (typ, val, num) in parsed[0]: + tags[val] = typ + return Tags(None, tags) + from_parsed_dn = Static(from_parsed_dn) + + def from_database(database): + return Tags(database, None) + from_database = Static(from_database) + + +def is_parsed_dn_parent (dn, parent): + if len(dn) <= len(parent): + return False + # Go backwards and validate each parent + for i in range(-1, -1 - len(parent)): + print i, dn[i], parent[i] + if dn[i] != parent[i]: + return False + return True + +class Database(Backend.Database): + def __init__(self, suffix): + Backend.Database.__init__(self, suffix) + self.storage = Storage(HOST) + try: + self.suffix_dn = ldap.dn.str2dn(self.suffix) + except ValueError: + raise Backend.Error(Backend.Error.PROTOCOL_ERROR, "invalid suffix dn") + + def __search_tag_keys(self, tags): + + if not len(tags): + return [] + + # Build up a filter + filter = [ldap.filter.filter_format("(%s=%s)", (TAG_ATTRIBUTE, tag)) for tag in tags] + if len(filter) > 1: + filter = "(&" + "".join(filter) + ")" + else: + filter = filter[0] + + try: + # Search for all those guys + results = self.storage.search(BASE, filter, [ KEY_ATTRIBUTE ]) + + except ldap.LDAPError, ex: + raise Backend.Error(Backend.OPERATIONS_ERROR, + "Couldn't search ldap for tags: %s" % ex.args[0]["desc"]) + + return [entry[KEY_ATTRIBUTE][0] for (dn, entry) in results if entry[KEY_ATTRIBUTE]] + + + def __search_key_dns(self, key): + + # Build up a filter + filter = ldap.filter.filter_format("(%s=%s)", (KEY_ATTRIBUTE, key)) + + try: + # Do the actual search + results = self.storage.search(BASE, filter) + + except ldap.LDAPError, ex: + raise Backend.Error(Backend.OPERATIONS_ERROR, + "Couldn't search ldap for keys: %s" % ex.args[0]["desc"]) + + return [dn for (dn, entry) in results if dn] + + + def __build_root_entry(self, tags, any_attrs = True): + attrs = { + "objectClass" : [ "top" ], + "hasSubordinates" : [ ] + } + + for (typ, val, num) in self.suffix_dn[0]: + if not attrs.has_key(typ): + attrs[typ] = [ ] + attrs[typ].append(val) + + # Note that we don't access 'tags' unless attributes requested + if any_attrs: + attrs["hasSubordinates"].append(tags and "FALSE" or "TRUE") + + return (self.suffix, attrs) + + + def __build_pivot_entry(self, tags, any_attrs = True): + attrs = { + REF_ATTRIBUTE : [ ], + "objectClass" : [ OBJECT_CLASS ], + "hasSubordinates" : [ "FALSE" ] + } + + # Build up a DN, and relevant attrs + rdn = [] + for tag, typ in tags.items(): + rdn.append((typ, tag, 1)) + if not attrs.has_key(typ): + attrs[typ] = [ ] + attrs[typ].append(tag) + dn = [ rdn ] + dn.extend(self.suffix_dn) + dn = ldap.dn.dn2str(dn) + + # Get out all the attributes + if any_attrs: + for key in self.__search_tag_keys(tags): + attrs[REF_ATTRIBUTE].append(key) + + return (dn, attrs) + + + + def __limit_results(self, args, results): + # TODO: Support sizelimit + # TODO: Support a filter + + # Only return the attribute names? + if args["attrsonly"] == "1": + for (dn, attrs) in results: + for attr in attrs: + attrs[attr] = [ "" ] + + # Only return these attributes? + which = args["attrs"] + if which != "all" and which != "*" and which != "+": + which = which.split(" ") + for (dn, attrs) in results: + for attr in attrs.keys(): + if attr not in which: + del attrs[attr] + + + def search(self, dn, args): + results = [] + + try: + parsed = ldap.dn.str2dn(dn) + except: + raise Backend.Error(Backend.PROTOCOL_ERROR, "Invalid dn in search: %s" % dn) + + # Arguments sent + scope = args["scope"] or SCOPE_BASE + any_attrs = len(args["attrs"].strip()) > 0 + + # Start at the root + if parsed == self.suffix_dn: + tags = Tags.from_database(self) + if scope == SCOPE_BASE or scope == SCOPE_SUB: + results.append(self.__build_root_entry(tags, any_attrs)) + if scope == SCOPE_ONE or scope == SCOPE_SUB: + # Process each tag individually, by default + for (tag, typ) in tags.items(): + results.append(self.__build_pivot_entry({ tag : typ }, any_attrs)) + + # Start at a tag + elif is_parsed_dn_parent (parsed, self.suffix_dn): + tags = Tags.from_parsed_dn(parsed) + if scope == SCOPE_BASE or scope == SCOPE_SUB: + results.append(self.__build_pivot_entry(tags, any_attrs)) + + # We don't have that base + else: + raise Backend.Error(Backend.NO_SUCH_OBJECT, "DN '%s' does not exist" % dn) + + self.__limit_results(args, results) + return results + + + def __build_key_mods(self, key, tags, op, mods): + dns = self.__search_key_dns(key) + if not dns: + raise Backend.Error(Backend.CONSTRAINT_VIOLATION, + "Cannot find an entry for %s '%s'" % (KEY_ATTRIBUTE, key)) + for dn in dns: + if not mods.has_key(dn): + mods[dn] = (key, []) + for tag in tags: + mods[dn][1].append((op, TAG_ATTRIBUTE, tag)) + + def modify(self, dn, mods): + + try: + parsed = ldap.dn.str2dn(dn) + except: + raise Backend.Error(Backend.PROTOCOL_ERROR, "Invalid dn in modify: %s" % dn) + + if dn == self.suffix: + raise Backend.Error(Backend.INSUFFICIENT_ACCESS, + "Cannot modify root dn of pivot area: %s" % dn) + + if not is_parsed_dn_parent (parsed, self.suffix_dn): + raise Backend.Error(Backend.NO_SUCH_OBJECT, + "DN '%s' does not exist" % dn) + + tags = Tags.from_parsed_dn(parsed) + + add_keys = sets.Set() + remove_keys = sets.Set() + remove_all = False + + # Parse out all the adds and removes + for (op, attr, value) in mods: + + if attr != REF_ATTRIBUTE: + raise Backend.Error(Backend.CONSTRAINT_VIOLATION, + "Cannot modify '%s' attribute" % attr) + + if op == ldap.MOD_ADD: + if value: + add_keys.add(value) + elif op == ldap.MOD_REPLACE: + remove_all = True + if value: + add_keys.add(value) + elif op == ldap.MOD_DELETE: + if value: + remove_keys.add(value) + else: + remove_all = True + else: + continue + + # Remove all of the ref attribute + if remove_all: + for key in self.__search_tag_keys (tags): + remove_keys.add(key) + + # Make them all unique, and non conflicting + for key in add_keys.copy(): + if key in remove_keys: + remove_keys.remove(key) + add_keys.remove(key) + + # Change them to DNs, and build mods objects for each dn + keys_and_mods_by_dn = { } + for key in add_keys: + self.__build_key_mods(key, tags, ldap.MOD_ADD, keys_and_mods_by_dn) + for key in remove_keys: + self.__build_key_mods(key, tags, ldap.MOD_DELETE, keys_and_mods_by_dn) + + + # Now perform the actual actions, combining errors + errors = [] + for (dn, (key, mod)) in keys_and_mods_by_dn.items(): + try: + print dn, mod + self.storage.modify(dn, mod) + except (ldap.TYPE_OR_VALUE_EXISTS, ldap.NO_SUCH_ATTRIBUTE): + continue + except ldap.NO_SUCH_OBJECT: + errors.append(key) + except ldap.LDAPError, ex: + raise Backend.Error(Backend.OPERATIONS_ERROR, + "Couldn't perform one of the modifications: %s" % ex.args[0]["desc"]) + + # Send back errors + if errors: + raise Backend.Error(Backend.CONSTRAINT_VIOLATION, + "Couldn't change %s for %s" % (KEY_ATTRIBUTE, ", ".join(errors))) + + |