#!/usr/bin/env python
"""Pivot backend: exposes TAG_ATTRIBUTE values as virtual group entries.

Every distinct TAG_ATTRIBUTE value found under BASE becomes a child entry of
the database suffix.  The REF_ATTRIBUTE of such a "pivot" entry lists the
KEY_ATTRIBUTE of each LDAP entry carrying that tag; changing REF_ATTRIBUTE
adds/removes the tag on those LDAP entries.  Any other attribute of a pivot
entry (including ACCESS_ATTRIBUTE, used for write authorization) is kept in a
local LDIF file (FILENAME).

Syntax targets Python 2.6+.
"""

import sys
import os
import sets

import ldap
import ldap.dn
import ldap.filter
import ldif

import Backend

HOST = "ldap://localhost:3890"
# NOTE(review): internal bind credentials are hard-coded; consider loading
# them from configuration instead of source.
ROOTDN = "cn=root,dc=fam"
PASSWORD = "barn"
BASE = "dc=fam"
REF_ATTRIBUTE = "member"      # pivot-entry attribute listing the tagged keys
KEY_ATTRIBUTE = "uid"         # identifies each tagged LDAP entry
TAG_ATTRIBUTE = "memberOf"    # LDAP attribute whose values become pivot entries
ACCESS_ATTRIBUTE = "access"   # stored attribute: bind DNs allowed to change a pivot
FILENAME = "/tmp/pivot.ldif"  # local LDIF store for non-LDAP attributes
OBJECT_CLASS = "group"
DN_ATTRIBUTE = "cn"           # RDN attribute type given to directory-derived tags

# Search scopes, compared against the string in args["scope"].
SCOPE_BASE = "0"
SCOPE_ONE = "1"
SCOPE_SUB = "2"

# Exceptions ldap.dn.str2dn may raise for malformed input.
_DN_PARSE_ERRORS = (ldap.LDAPError, ValueError, TypeError)


def _ldap_error_desc(ex):
    """Return the human-readable description carried by an LDAPError."""
    info = ex.args[0] if ex.args else None
    if isinstance(info, dict) and "desc" in info:
        return info["desc"]
    return str(ex)


def _dn_key(parsed):
    """Return a parsed DN with case folded and AVA flags dropped, for comparisons.

    The AVAs of each RDN are sorted so multi-valued RDN order does not matter.
    """
    return [sorted((typ.lower(), val.lower()) for (typ, val, _flags) in rdn)
            for rdn in parsed]


def _dn_equal(first, second):
    """Return True if two parsed DNs name the same entry (case-insensitively)."""
    return _dn_key(first) == _dn_key(second)


def _same_attr(first, second):
    """Return True if two attribute type names are equal (LDAP names ignore case)."""
    return first.lower() == second.lower()


def _unique(values):
    """Return ``values`` without duplicates, keeping first-seen order."""
    seen = set()
    result = []
    for value in values:
        if value not in seen:
            seen.add(value)
            result.append(value)
    return result


class Lookups(object):
    """Lazily opened LDAP connection bound as ROOTDN, used for internal lookups.

    Each operation reconnects and re-binds when the server reports SERVER_DOWN,
    up to ``retries`` times, before re-raising.
    """

    def __init__(self, url):
        self.url = url
        self.ldap = None

    def __connect(self, force=False):
        """Open and bind the connection if needed; ``force`` replaces an existing one.

        Raises:
            Backend.Error(OPERATIONS_ERROR) if the internal bind fails.
        """
        if self.ldap is not None:
            if not force:
                return
            old, self.ldap = self.ldap, None
            try:
                old.unbind()
            except ldap.LDAPError:
                # The connection is being discarded (usually because the server
                # went away), so a failing unbind is expected and harmless.
                pass
        conn = ldap.initialize(self.url)
        try:
            conn.simple_bind_s(ROOTDN, PASSWORD)
        except ldap.LDAPError as ex:
            # Never keep an unbound connection: later calls would reuse it
            # without binding and silently run anonymously.
            raise Backend.Error(Backend.OPERATIONS_ERROR,
                                "Couldn't do internal authenticate: %s"
                                % _ldap_error_desc(ex))
        self.ldap = conn

    def __run(self, operation, retries):
        """Return ``operation(connection)``, reconnecting on SERVER_DOWN up to ``retries`` times."""
        while True:
            self.__connect()
            try:
                return operation(self.ldap)
            except ldap.SERVER_DOWN:
                if retries <= 0:
                    raise
                retries -= 1
                self.__connect(True)

    def read(self, dn, filtstr="(objectClass=*)", attrs=None, retries=1):
        """Return the attribute dict of entry ``dn`` (base-scope search), or None."""
        results = self.__run(
            lambda conn: conn.search_s(dn, ldap.SCOPE_BASE, filtstr, attrs, 0),
            retries)
        if not results:
            return None
        return results[0][1]

    def search(self, base, filtstr, attrs=None, retries=1):
        """Subtree-search ``base`` and return python-ldap's list of (dn, entry) tuples."""
        if attrs is None:
            attrs = []
        return self.__run(
            lambda conn: conn.search_s(base, ldap.SCOPE_SUBTREE, filtstr, attrs, 0),
            retries)

    def modify(self, dn, mods, retries=1):
        """Apply the python-ldap modlist ``mods`` to ``dn``."""
        self.__run(lambda conn: conn.modify_s(dn, mods), retries)


class Storage(object):
    """In-memory map of dn -> {attribute: [values]}, persisted as an LDIF file."""

    def __init__(self, filename):
        self.filename = filename
        self.entries = {}
        self.load()

    def load(self):
        """Replace the in-memory entries with the file's; a missing file is left as-is."""
        if not os.path.exists(self.filename):
            return
        with open(self.filename, 'r') as stream:
            reader = ldif.LDIFRecordList(stream)
            reader.parse()
        self.entries = {}
        for (dn, entry) in reader.all_records:
            self.entries[dn] = entry

    def save(self):
        """Write every non-empty entry to the LDIF file, replacing it atomically."""
        temporary = self.filename + ".tmp"
        with open(temporary, 'w') as output:
            output.write("# Overwritten automatically, do not edit\n\n")
            writer = ldif.LDIFWriter(output)
            for (dn, entry) in self.entries.items():
                if entry:
                    writer.unparse(dn, entry)
        # rename() replaces the target atomically on POSIX, so a crash while
        # writing cannot leave a truncated store (and lost access lists) behind.
        os.rename(temporary, self.filename)

    def __entry_for_dn(self, dn):
        """Return the mutable entry for ``dn``, creating it; used by write paths only."""
        return self.entries.setdefault(dn, {})

    def store(self, dn, attribute, value):
        """Add ``value`` to ``attribute`` of ``dn`` unless present; None is ignored."""
        if value is None:
            return
        values = self.__entry_for_dn(dn).setdefault(attribute, [])
        if value not in values:
            values.append(value)

    def remove(self, dn, attribute, value=None):
        """Remove one value of ``attribute`` (or the whole attribute when ``value`` is None)."""
        entry = self.entries.get(dn)
        if not entry or attribute not in entry:
            return
        if value is None:
            del entry[attribute]
        elif value in entry[attribute]:
            remaining = [val for val in entry[attribute] if val != value]
            if remaining:
                entry[attribute] = remaining
            else:
                # Drop the attribute rather than keeping an empty value list.
                del entry[attribute]

    def has(self, dn, attribute, value=None):
        """Return True if ``dn`` has ``attribute`` (holding ``value``, when given)."""
        values = self.entries.get(dn, {}).get(attribute)
        if values is None:
            return False
        if value is None:
            return True
        return value in values

    def retrieve(self, dn, attribute):
        """Return a copy of the values of ``attribute`` for ``dn`` ([] if none)."""
        return list(self.entries.get(dn, {}).get(attribute, []))

    def list_attributes(self, dn):
        """Return a copy of the attribute names stored for ``dn``."""
        return list(self.entries.get(dn, {}).keys())

    def list_dns(self):
        """Return a copy of all stored DNs."""
        return list(self.entries.keys())

    def exists(self, dn):
        """Return True if ``dn`` has an entry in storage."""
        return dn in self.entries

    def delete(self, dn):
        """Forget ``dn`` and all of its stored attributes (no-op if absent)."""
        self.entries.pop(dn, None)


class Static(object):
    """Legacy callable wrapper kept for compatibility; prefer ``staticmethod``."""

    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)


class Tags(object):
    """Read-only mapping of tag value -> RDN attribute type.

    Built either with fixed contents (from a parsed DN) or lazily from the
    directory, where each TAG_ATTRIBUTE value under BASE maps to DN_ATTRIBUTE.
    """

    def __init__(self, database, tags):
        self.database = database
        self.tags = tags

    def __refresh(self, force=False):
        """Load tags from the directory unless already known (or ``force``)."""
        if not force and self.tags is not None:
            return
        try:
            results = self.database.lookups.search(
                BASE, "(%s=*)" % TAG_ATTRIBUTE, [TAG_ATTRIBUTE])
        except ldap.LDAPError as ex:
            raise Backend.Error(Backend.OPERATIONS_ERROR,
                                "Couldn't search ldap for keys: %s"
                                % _ldap_error_desc(ex))
        tags = {}
        for (dn, entry) in results:
            if not dn:
                continue  # results without a DN carry no attribute dict
            for values in entry.values():
                for value in values:
                    tags[value] = DN_ATTRIBUTE
        self.tags = tags

    def __len__(self):
        self.__refresh()
        return len(self.tags)

    def __getitem__(self, k):
        self.__refresh()
        return self.tags[k]

    def __setitem__(self, k, v):
        raise TypeError("Tags is read-only")

    def __delitem__(self, k):
        raise TypeError("Tags is read-only")

    def __contains__(self, k):
        self.__refresh()
        return k in self.tags

    def __iter__(self):
        self.__refresh()
        return iter(self.tags)

    def items(self):
        self.__refresh()
        return self.tags.items()

    @staticmethod
    def from_dn(dn):
        """Build Tags from the first RDN of DN string ``dn``.

        Raises:
            Backend.Error(PROTOCOL_ERROR) if ``dn`` cannot be parsed.
        """
        try:
            parsed = ldap.dn.str2dn(dn)
        except _DN_PARSE_ERRORS:
            raise Backend.Error(Backend.PROTOCOL_ERROR, "invalid dn: %s" % dn)
        return Tags.from_parsed_dn(parsed)

    @staticmethod
    def from_parsed_dn(parsed):
        """Build Tags mapping each AVA value of the first RDN to its type (empty DN -> no tags)."""
        tags = {}
        if parsed:
            for (typ, val, _flags) in parsed[0]:
                tags[val] = typ
        return Tags(None, tags)

    @staticmethod
    def from_database(database):
        """Build Tags that are loaded lazily from ``database.lookups``."""
        return Tags(database, None)


def parse_dn(dn):
    """Parse DN string ``dn`` with ldap.dn.str2dn.

    Raises:
        Backend.Error(PROTOCOL_ERROR) if ``dn`` is malformed.
    """
    try:
        return ldap.dn.str2dn(dn)
    except _DN_PARSE_ERRORS:
        raise Backend.Error(Backend.PROTOCOL_ERROR, "Invalid dn: %s" % dn)


def is_parsed_dn_parent(dn, parent):
    """Return True if parsed DN ``dn`` is an immediate child of parsed DN ``parent``."""
    return len(dn) == len(parent) + 1 and _dn_equal(dn[1:], parent)


class Database(Backend.Database):
    """Backend database serving pivot entries directly below ``suffix``.

    NOTE(review): ``self.suffix`` and ``self.binddn`` are presumably set by
    Backend.Database; that class is not visible here.
    """

    def __init__(self, suffix):
        Backend.Database.__init__(self, suffix)
        self.lookups = Lookups(HOST)
        self.suffix_dn = parse_dn(self.suffix)
        self.storage = Storage(FILENAME)

    def __search_tag_keys(self, tags):
        """Return the KEY_ATTRIBUTE value of every LDAP entry carrying all ``tags``."""
        if not tags:
            return []
        clauses = [ldap.filter.filter_format("(%s=%s)", (TAG_ATTRIBUTE, tag))
                   for tag in tags]
        if len(clauses) > 1:
            search_filter = "(&" + "".join(clauses) + ")"
        else:
            search_filter = clauses[0]
        try:
            results = self.lookups.search(BASE, search_filter, [KEY_ATTRIBUTE])
        except ldap.LDAPError as ex:
            raise Backend.Error(Backend.OPERATIONS_ERROR,
                                "Couldn't search ldap for tags: %s"
                                % _ldap_error_desc(ex))
        keys = []
        for (dn, entry) in results:
            if not dn:
                continue
            # Entries without the key attribute are skipped instead of raising KeyError.
            values = entry.get(KEY_ATTRIBUTE)
            if values:
                keys.append(values[0])
        return keys

    def __search_key_dns(self, key):
        """Return the DNs of all LDAP entries whose KEY_ATTRIBUTE equals ``key``."""
        search_filter = ldap.filter.filter_format("(%s=%s)", (KEY_ATTRIBUTE, key))
        try:
            results = self.lookups.search(BASE, search_filter)
        except ldap.LDAPError as ex:
            raise Backend.Error(Backend.OPERATIONS_ERROR,
                                "Couldn't search ldap for keys: %s"
                                % _ldap_error_desc(ex))
        return [dn for (dn, entry) in results if dn]

    def __build_root_entry(self, tags, with_attrs=True):
        """Return (dn, attrs) for the suffix entry."""
        attrs = {"objectClass": ["top"], "hasSubordinates": []}
        for (typ, val, _flags) in self.suffix_dn[0]:
            attrs.setdefault(typ, []).append(val)
        if with_attrs:
            # Evaluating a directory-backed Tags triggers an LDAP search, so
            # it is only done when attributes were requested.
            attrs["hasSubordinates"].append("TRUE" if tags else "FALSE")
        return (self.suffix, attrs)

    def __build_pivot_entry(self, tags, with_attrs=True):
        """Return (dn, attrs) for the pivot entry whose RDN is built from ``tags``."""
        attrs = {REF_ATTRIBUTE: [],
                 "objectClass": [OBJECT_CLASS],
                 "hasSubordinates": ["FALSE"]}
        rdn = []
        for (tag, typ) in tags.items():
            rdn.append((typ, tag, 1))
            attrs.setdefault(typ, []).append(tag)
        dn = ldap.dn.dn2str([rdn] + list(self.suffix_dn))
        if with_attrs:
            attrs[REF_ATTRIBUTE].extend(self.__search_tag_keys(tags))
        return (dn, attrs)

    def __build_storage_entry(self, parsed_dn, with_attrs=True):
        """Return (dn, RDN attrs) for a storage-only entry; stored attrs are merged later."""
        attrs = {}
        for (typ, val, _flags) in parsed_dn[0]:
            attrs.setdefault(typ, []).append(val)
        return (ldap.dn.dn2str(parsed_dn), attrs)

    def __complete_results(self, args, entries):
        """Merge stored attributes into ``entries`` and apply attribute selection.

        Returns a list of (dn, entry) tuples.
        """
        # TODO: Support sizelimit
        # TODO: Support a filter
        only_names = (args["attrsonly"] == "1")
        requested = args["attrs"].split()
        all_attrs = any(name in ("all", "*", "+") for name in requested)
        wanted = set(name.lower() for name in requested)

        results = []
        for (dn, entry) in entries.items():
            extra = self.storage.list_attributes(dn)
            if only_names:
                for attr in list(entry.keys()) + extra:
                    entry[attr] = [""]
            else:
                for attr in extra:
                    entry[attr] = _unique(entry.get(attr, []) +
                                          self.storage.retrieve(dn, attr))
            if not all_attrs:
                # Attribute type names are case-insensitive in LDAP.
                for attr in list(entry.keys()):
                    if attr.lower() not in wanted:
                        del entry[attr]
            results.append((dn, entry))
        return results

    def search(self, dn, args):
        """Search from base ``dn`` using args["scope"] and args["attrs"].

        Raises:
            Backend.Error(NO_SUCH_OBJECT) if ``dn`` is not served here.
        """
        results = {}
        parsed = parse_dn(dn)
        normalized = ldap.dn.dn2str(parsed)
        scope = args["scope"] or SCOPE_BASE
        with_attrs = len(args["attrs"].strip()) > 0
        want_base = scope in (SCOPE_BASE, SCOPE_SUB)
        want_children = scope in (SCOPE_ONE, SCOPE_SUB)

        if _dn_equal(parsed, self.suffix_dn):
            # Start at the root
            tags = Tags.from_database(self)
            if want_base:
                (entry_dn, entry) = self.__build_root_entry(tags, with_attrs)
                results[entry_dn] = entry
            if want_children:
                # One pivot entry per tag found in the directory...
                for (tag, typ) in tags.items():
                    (child, entry) = self.__build_pivot_entry({tag: typ}, with_attrs)
                    results[child] = entry
                # ...plus storage-only entries not already listed.
                for child in self.storage.list_dns():
                    if child not in results:
                        (child, entry) = self.__build_storage_entry(parse_dn(child))
                        results[child] = entry
        elif is_parsed_dn_parent(parsed, self.suffix_dn):
            # Start at a tag
            if want_base:
                tags = Tags.from_parsed_dn(parsed)
                (entry_dn, entry) = self.__build_pivot_entry(tags, with_attrs)
                results[entry_dn] = entry
        elif self.storage.exists(normalized):
            if want_base:
                (entry_dn, entry) = self.__build_storage_entry(parsed, with_attrs)
                results[entry_dn] = entry
        else:
            raise Backend.Error(Backend.NO_SUCH_OBJECT,
                                "DN '%s' does not exist" % dn)
        return self.__complete_results(args, results)

    def __build_key_mods(self, key, tags, op, mods):
        """Queue ``op`` of every tag for each LDAP entry with KEY_ATTRIBUTE ``key``.

        ``mods`` maps LDAP dn -> (key, modlist) and is updated in place.

        Raises:
            Backend.Error(CONSTRAINT_VIOLATION) if no entry has that key.
        """
        dns = self.__search_key_dns(key)
        if not dns:
            raise Backend.Error(Backend.CONSTRAINT_VIOLATION,
                                "Cannot find an entry for %s '%s'"
                                % (KEY_ATTRIBUTE, key))
        for dn in dns:
            (_key, changes) = mods.setdefault(dn, (key, []))
            for tag in tags:
                changes.append((op, TAG_ATTRIBUTE, tag))

    def __check_write_access(self, dn):
        """Return True if the bound DN is ROOTDN or listed in ``dn``'s ACCESS_ATTRIBUTE."""
        if self.binddn == ROOTDN:
            return True
        return self.storage.has(dn, ACCESS_ATTRIBUTE, self.binddn)

    def add(self, dn, entry):
        """Create pivot entry ``dn`` from ``entry`` ({attr: [values]}) and save storage."""
        parsed = parse_dn(dn)
        normalized = ldap.dn.dn2str(parsed)
        if _dn_equal(parsed, self.suffix_dn):
            raise Backend.Error(Backend.ALREADY_EXISTS,
                                "This entry already exists: %s" % dn)
        if not is_parsed_dn_parent(parsed, self.suffix_dn):
            raise Backend.Error(Backend.NO_SUCH_OBJECT,
                                "Parent of '%s' does not exist or is not valid" % dn)
        if self.storage.exists(normalized):
            raise Backend.Error(Backend.ALREADY_EXISTS,
                                "This entry already exists: %s" % dn)
        tags = Tags.from_parsed_dn(parsed)
        if self.__search_tag_keys(tags):
            raise Backend.Error(Backend.ALREADY_EXISTS,
                                "This entry already exists: %s" % dn)

        # Everyone has implicit access to create a new group.
        # Convert into a modify change set, one tuple per value.
        mods = [(ldap.MOD_ADD, attr, value)
                for (attr, values) in entry.items()
                for value in values]
        # Grant the creator write access unless the entry sets it explicitly.
        if self.binddn and ACCESS_ATTRIBUTE not in entry:
            mods.append((ldap.MOD_ADD, ACCESS_ATTRIBUTE, self.binddn))

        self.__change(parsed, mods, tags)
        self.storage.save()

    def delete(self, dn, args):
        """Remove every key from pivot entry ``dn`` and drop its stored attributes."""
        parsed = parse_dn(dn)
        normalized = ldap.dn.dn2str(parsed)
        if _dn_equal(parsed, self.suffix_dn):
            raise Backend.Error(Backend.NOT_ALLOWED_ON_NONLEAF,
                                "Cannot delete the root entry: %s" % dn)
        if not is_parsed_dn_parent(parsed, self.suffix_dn):
            raise Backend.Error(Backend.NO_SUCH_OBJECT,
                                "Entry does not exist: %s" % dn)
        if not self.__check_write_access(normalized):
            raise Backend.Error(Backend.INSUFFICIENT_ACCESS,
                                "Access denied to delete entry: %s" % dn)
        tags = Tags.from_parsed_dn(parsed)
        mods = [(ldap.MOD_DELETE, REF_ATTRIBUTE, None)]
        self.__change(parsed, mods, tags)
        self.storage.delete(normalized)
        self.storage.save()

    def modify(self, dn, mods):
        """Apply (op, attr, value) ``mods`` to pivot entry ``dn`` and save storage."""
        parsed = parse_dn(dn)
        normalized = ldap.dn.dn2str(parsed)
        if _dn_equal(parsed, self.suffix_dn):
            raise Backend.Error(Backend.INSUFFICIENT_ACCESS,
                                "Cannot modify root dn of pivot area: %s" % dn)
        if not is_parsed_dn_parent(parsed, self.suffix_dn):
            raise Backend.Error(Backend.NO_SUCH_OBJECT,
                                "DN '%s' does not exist" % dn)
        if not self.__check_write_access(normalized):
            raise Backend.Error(Backend.INSUFFICIENT_ACCESS,
                                "Access denied to modify entry: %s" % dn)
        tags = Tags.from_parsed_dn(parsed)
        self.__change(parsed, mods, tags)
        self.storage.save()

    def __change(self, parsed, mods, tags):
        """Apply (op, attr, value) ``mods`` to the pivot entry ``parsed``.

        REF_ATTRIBUTE changes become TAG_ATTRIBUTE adds/deletes on the LDAP
        entries owning the keys; all other attributes change in local storage
        (the caller saves it).

        Raises:
            Backend.Error(CONSTRAINT_VIOLATION) if a key has no LDAP entry or an
            entry disappeared; Backend.Error(OPERATIONS_ERROR) on other LDAP errors.
        """
        add_keys = set()
        remove_keys = set()
        remove_all = False

        for (op, attr, value) in mods:
            if not _same_attr(attr, REF_ATTRIBUTE):
                continue  # applied to storage below
            if op == ldap.MOD_ADD:
                if value:
                    add_keys.add(value)
            elif op == ldap.MOD_REPLACE:
                remove_all = True
                if value:
                    add_keys.add(value)
            elif op == ldap.MOD_DELETE:
                if value:
                    remove_keys.add(value)
                else:
                    remove_all = True

        if remove_all:
            remove_keys.update(self.__search_tag_keys(tags))

        # A key both added and removed is left unchanged.
        both = add_keys & remove_keys
        add_keys -= both
        remove_keys -= both

        # Change keys into LDAP DNs, collecting one modlist per DN.
        keys_and_mods_by_dn = {}
        for key in add_keys:
            self.__build_key_mods(key, tags, ldap.MOD_ADD, keys_and_mods_by_dn)
        for key in remove_keys:
            self.__build_key_mods(key, tags, ldap.MOD_DELETE, keys_and_mods_by_dn)

        errors = []
        for (dn, (key, changes)) in keys_and_mods_by_dn.items():
            try:
                self.__modify_ignoring_conflicts(dn, changes)
            except ldap.NO_SUCH_OBJECT:
                errors.append(key)
            except ldap.LDAPError as ex:
                raise Backend.Error(Backend.OPERATIONS_ERROR,
                                    "Couldn't perform one of the modifications: %s"
                                    % _ldap_error_desc(ex))
        if errors:
            raise Backend.Error(Backend.CONSTRAINT_VIOLATION,
                                "Couldn't change %s for %s"
                                % (KEY_ATTRIBUTE, ", ".join(errors)))

        self.__apply_storage_mods(ldap.dn.dn2str(parsed), mods)

    def __modify_ignoring_conflicts(self, dn, changes):
        """Apply ``changes`` to LDAP entry ``dn``, skipping values already added/removed.

        An LDAP Modify is applied atomically (RFC 4511), so one already-present
        or already-absent tag would reject the whole request; in that case each
        change is retried on its own so the remaining ones still take effect.
        """
        conflicts = (ldap.TYPE_OR_VALUE_EXISTS, ldap.NO_SUCH_ATTRIBUTE)
        try:
            self.lookups.modify(dn, changes)
        except conflicts:
            if len(changes) <= 1:
                return
            for change in changes:
                try:
                    self.lookups.modify(dn, [change])
                except conflicts:
                    pass

    def __apply_storage_mods(self, dn, mods):
        """Apply every non-REF_ATTRIBUTE change in ``mods`` to storage entry ``dn``."""
        replacing = None  # attribute of the MOD_REPLACE run being applied
        for (op, attr, value) in mods:
            if _same_attr(attr, REF_ATTRIBUTE):
                replacing = None
                continue
            if op == ldap.MOD_REPLACE:
                # NOTE(review): assumes a multi-valued replace arrives as
                # consecutive tuples for one attribute; clear it once, then
                # store every value (previously only the last one survived).
                if replacing != attr:
                    self.storage.remove(dn, attr)
                    replacing = attr
                self.storage.store(dn, attr, value)
                continue
            replacing = None
            if op == ldap.MOD_ADD:
                self.storage.store(dn, attr, value)
            elif op == ldap.MOD_DELETE:
                self.storage.remove(dn, attr, value)