author    Stef Walter <stef@thewalter.net>    2008-06-06 17:35:32 +0000
committer Stef Walter <stef@thewalter.net>    2008-06-06 17:35:32 +0000
commit    ea4a294b84254fb19b2edd24683fa6501970d049
tree      cf84f87373ab27562f47093a0b3af70fb0eef156
Initial import
-rw-r--r--  Backend.py        393
-rw-r--r--  Pivot.py          397
-rw-r--r--  slapd-pivot.py     18
3 files changed, 808 insertions, 0 deletions
diff --git a/Backend.py b/Backend.py
new file mode 100644
index 0000000..6dc5497
--- /dev/null
+++ b/Backend.py
@@ -0,0 +1,393 @@
+#!/usr/bin/env python
+
+from __future__ import with_statement
+import sys, os, time, re
+import threading, mutex
+import SocketServer, StringIO
+import ldap, ldif
+
+debug = 0
+
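+# Standard LDAP resultCode values, written in hex.  These are the codes sent
+# back to slapd in the "code:" line of each RESULT response.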
+OPERATIONS_ERROR = 0x01
+PROTOCOL_ERROR = 0x02
+TIMELIMIT_EXCEEDED = 0x03
+SIZELIMIT_EXCEEDED = 0x04
+COMPARE_FALSE = 0x05
+COMPARE_TRUE = 0x06
+AUTH_METHOD_NOT_SUPPORTED = 0x07
+STRONG_AUTH_NOT_SUPPORTED = 0x07
+STRONG_AUTH_REQUIRED = 0x08
+STRONGER_AUTH_REQUIRED = 0x08
+PARTIAL_RESULTS = 0x09
+ADMINLIMIT_EXCEEDED = 0x0b
+CONFIDENTIALITY_REQUIRED = 0x0d
+SASL_BIND_IN_PROGRESS = 0x0e
+NO_SUCH_ATTRIBUTE = 0x10
+UNDEFINED_TYPE = 0x11
+INAPPROPRIATE_MATCHING = 0x12
+CONSTRAINT_VIOLATION = 0x13
+TYPE_OR_VALUE_EXISTS = 0x14
+INVALID_SYNTAX = 0x15
+NO_SUCH_OBJECT = 0x20
+ALIAS_PROBLEM = 0x21
+INVALID_DN_SYNTAX = 0x22
+IS_LEAF = 0x23
+ALIAS_DEREF_PROBLEM = 0x24
+X_PROXY_AUTHZ_FAILURE = 0x2F
+INAPPROPRIATE_AUTH = 0x30
+INVALID_CREDENTIALS = 0x31
+INSUFFICIENT_ACCESS = 0x32
+BUSY = 0x33
+UNAVAILABLE = 0x34
+UNWILLING_TO_PERFORM = 0x35
+LOOP_DETECT = 0x36
+NAMING_VIOLATION = 0x40
+OBJECT_CLASS_VIOLATION = 0x41
+NOT_ALLOWED_ON_NONLEAF = 0x42
+NOT_ALLOWED_ON_RDN = 0x43
+ALREADY_EXISTS = 0x44
+NO_OBJECT_CLASS_MODS = 0x45
+RESULTS_TOO_LARGE = 0x46
+AFFECTS_MULTIPLE_DSAS = 0x47
+OTHER_ERROR = 0x50
+
+def split_argument(line):
+    parts = line.strip().split(':', 1)
+ name = parts[0].strip()
+ value = ""
+ if len(parts) == 2:
+ value = parts[1].strip()
+ return (name, value)
+
+
+class Parser(ldif.LDIFParser):
+ def __init__(self, input):
+ ldif.LDIFParser.__init__(self, input)
+ self.dn = None
+ self.entry = None
+ def handle(self, dn, entry):
+ if self.entry is None:
+ self.dn = dn
+ self.entry = entry
+
+
+class Database:
+ def __init__(self, suffix):
+ self.suffix = suffix
+ self.binddn = None
+ self.remote = None
+ self.mutex = threading.Lock()
+
+    # Overridable to handle specific commands
+    def add(self, dn, entry):
+        raise Error(UNWILLING_TO_PERFORM, "Add not implemented")
+    def bind(self, dn, args):
+        raise Error(UNWILLING_TO_PERFORM, "Bind not implemented")
+    def compare(self, dn, entry):
+        raise Error(UNWILLING_TO_PERFORM, "Compare not implemented")
+    def delete(self, dn, args):
+        raise Error(UNWILLING_TO_PERFORM, "Delete not implemented")
+    def modify(self, dn, modlist):
+        raise Error(UNWILLING_TO_PERFORM, "Modify not implemented")
+    def modrdn(self, dn, args):
+        raise Error(UNWILLING_TO_PERFORM, "ModRDN not implemented")
+    def search(self, dn, args):
+        raise Error(UNWILLING_TO_PERFORM, "Search not implemented")
+
+ # Overridable to handle all processing
+ def process(self, command, dn, block):
+ try:
+
+            # These two commands need the raw block parsed specially
+            if command == "MODIFY":
+                return self.process_modify_internal(dn, block)
+
+            elif command == "ADD":
+                return self.process_add_internal(dn, block)
+
+ # All the rest we split up, multiple args go into arrays
+ args = Arguments()
+ while True:
+ line = block.readline()
+ if len(line) == 0:
+ break
+ (name, value) = split_argument(line)
+ args.add(name, value)
+
+ if command == "BIND":
+ self.bind(dn, args)
+ elif command == "COMPARE":
+ result = self.compare(dn, args)
+ return (result, "", None)
+ elif command == "DELETE":
+ self.delete(dn, args)
+ elif command == "MODIFY":
+ self.modify(dn, args)
+ elif command == "MODRDN":
+ self.modrdn(dn, args)
+ elif command == "SEARCH":
+ return self.process_search_internal(dn, args)
+ elif command == "UNBIND":
+ assert False # should have been handled in caller
+ else:
+                return (UNWILLING_TO_PERFORM, "Unsupported operation %s" % command, None)
+
+ return (0, "", [])
+
+ except Error, ex:
+ return (ex.code, ex.info, None)
+
+ def process_locked(self, command, dn, block, binddn, remote):
+ with self.mutex:
+ self.binddn = binddn
+ self.remote = remote
+ result = self.process(command, dn, block)
+ return result
+
+ def process_add_internal(self, dn, block):
+ parser = Parser(block)
+ parser.parse()
+ self.add(dn, parser.entry)
+ return (0, "", [])
+
+
+    def process_compare_internal(self, dn, block):
+ parser = Parser(block)
+ parser.parse()
+ result = self.compare(dn, parser.entry)
+ return (result, "", [])
+
+
+ def process_search_internal(self, dn, args):
+ results = []
+ data = self.search(args["base"] or dn, args)
+ for (dn, entry) in data:
+ result = StringIO.StringIO()
+ writer = ldif.LDIFWriter(result)
+ writer.unparse(dn, entry)
+ results.append(result.getvalue())
+ return (0, "", results)
+
+
+ def process_modify_internal(self, dn, block):
+
+ op = None
+ attr = None
+ batch = 0
+ mods = []
+
+ while True:
+ line = block.readline()
+ if len(line) == 0:
+ break
+ line = line.strip()
+
+ # A break between different mods
+ if line == "-":
+
+                # Last batch was empty, so delete/replace all values
+ if batch == 0:
+ if op == ldap.MOD_DELETE:
+ mods.append((op, attr, None))
+ elif op == ldap.MOD_REPLACE:
+ mods.append((op, attr, None))
+
+ batch = 0
+ op = None
+ attr = None
+ continue
+
+ # The current line
+ (name, value) = split_argument(line)
+
+ # Don't have a mod type yet
+ if op is None:
+ attr = value
+ if name == "add":
+ op = ldap.MOD_ADD
+ elif name == "replace":
+ op = ldap.MOD_REPLACE
+ elif name == "delete":
+ op = ldap.MOD_DELETE
+ else:
+ op = None
+
+ # Have a op, add values
+ elif name == attr:
+ assert attr is not None
+ mods.append((op, attr, value))
+ batch += 1
+
+ print mods
+ self.modify(dn, mods)
+ return (0, "", [])
+
+
+class Error(Exception):
+ """Exception to be returned to server"""
+
+ def __init__(self, code = OPERATIONS_ERROR, info = ""):
+ self.code = code
+ self.info = info
+
+ def __str__(self):
+ return "%d: %s" % (self.code, self.info)
+
+class Arguments:
+ def __init__(self):
+ self.dict = {}
+ def add(self, name, value):
+ if self.dict.has_key(name):
+ if type(self.dict[name]) != type([]):
+ self.dict[name] = [self.dict[name]]
+ self.dict[name].append(value)
+ else:
+ self.dict[name] = value
+ def __getitem__(self, name):
+ if self.dict.has_key(name):
+ return self.dict[name]
+ return None
+ def __len__(self):
+ return len(self.dict)
+
+class Connection(SocketServer.BaseRequestHandler):
+
+ def __init__(self, request, client_address, server):
+ server.unique += 1
+ self.identifier = server.unique
+ self.block_regex = re.compile("\r?\n[ \t]*\r?\n")
+ SocketServer.BaseRequestHandler.__init__(self, request, client_address, server)
+
+ def trace(self, message):
+ global debug
+ if debug:
+ prefix = "%04d " % self.identifier
+ lines = message.split("\n")
+ print >> sys.stderr, prefix + lines[0]
+ print >> sys.stderr, "\n".join([prefix + "*** " + line for line in lines[1:]])
+
+ def handle(self):
+
+ self.trace("CONNECTED")
+
+ req = self.request
+ req.setblocking(1)
+ req.settimeout(None)
+
+ extra = ""
+ block = None
+ while True:
+            chunk = req.recv(1024)
+            data = extra + chunk
+
+            # End of connection (even mid-block)
+            if len(chunk) == 0:
+                break
+
+ parts = self.block_regex.split(data)
+ if len(parts) > 1:
+ block = unicode(parts[0], "utf-8", "strict")
+ break
+ extra = parts[0]
+
+ if block:
+ self.trace("REQUEST\n%s\n" % block)
+ self.handle_block(req, StringIO.StringIO(block))
+
+ self.trace("DISCONNECTING")
+ self.request.close()
+
+ def handle_block(self, req, block):
+
+ line = block.readline()
+ command = line.strip()
+
+ # Disconnect immediately on certain occasions
+ if not command:
+ return False
+ elif command == "UNBIND":
+ return False
+
+ suffixes = []
+ binddn = None
+ remote = None
+ ssf = None
+ msgid = None
+ dn = None
+
+ while True:
+ off = block.tell()
+
+ line = block.readline()
+ (name, value) = split_argument(line)
+
+ if name == "suffix":
+ suffixes.append(value)
+ elif name == "msgid" and msgid is None:
+ msgid = value
+ elif name == "dn" and dn is None:
+ dn = value.lower()
+ elif name == "binddn" and binddn is None:
+ binddn = value.lower()
+ elif name == "peername" and remote is None:
+ remote = value
+ elif name == "ssf" and ssf is None:
+ ssf = value
+ else: # Return this line and continue
+ block.seek(off)
+ break
+
+ code = 0
+ info = ""
+ data = None
+
+ if len(suffixes) == 0:
+ code = OPERATIONS_ERROR
+ info = "No suffix specified"
+ elif len(suffixes) > 1:
+ code = OPERATIONS_ERROR
+ info = "Multiple suffixes not supported"
+ else:
+ database = self.server.find_database(suffixes[0])
+ (code, info, data) = database.process_locked(command, dn, block, binddn, remote)
+
+ if data:
+ for dat in data:
+ self.trace("DATA\n%s\n" % dat)
+ req.sendall(dat.strip("\n") + "\n\n")
+
+ result = "RESULT"
+ if info:
+ result += "\ninfo: %s" % info
+ # BUG: Current OpenLDAP always wants code last
+ result += "\ncode: %d" % code
+ self.trace("RESPONSE\n%s" % result.strip("\n"))
+ req.sendall(result)
+
+ return True
+
+
+class Server(SocketServer.ThreadingMixIn, SocketServer.UnixStreamServer):
+ daemon_threads = True
+ allow_reuse_address = True
+
+ def __init__(self, address, DatabaseClass, debug = False):
+ if (os.path.exists(address)):
+ os.unlink(address)
+ SocketServer.UnixStreamServer.__init__(self, address, Connection)
+ self.DatabaseClass = DatabaseClass
+ self.__databases = {}
+ self.unique = 0
+
+ def find_database(self, suffix):
+ suffix = suffix.lower()
+ if self.__databases.has_key(suffix):
+ database = self.__databases[suffix]
+ else:
+ database = self.DatabaseClass(suffix)
+ self.__databases[suffix] = database
+ return database
+
+
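For orientation, a sketch (not part of the commit) of the line-oriented exchange that Connection.handle_block above parses. The group name and member value are invented; the header lines shown are just the ones the code recognizes, with the remaining lines handed to Database.process as search arguments:

    SEARCH
    msgid: 1
    suffix: dc=fam
    base: dc=fam
    scope: 2
    attrs: *

A blank line terminates the request block. The backend replies with zero or more LDIF entries, each followed by a blank line, and then a RESULT block:

    dn: cn=admins,dc=fam
    objectClass: groupOfNames
    cn: admins
    hasSubordinates: FALSE
    member: alice

    RESULT
    code: 0
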
diff --git a/Pivot.py b/Pivot.py
new file mode 100644
index 0000000..28cea92
--- /dev/null
+++ b/Pivot.py
@@ -0,0 +1,397 @@
+#!/usr/bin/env python
+
+import ldap, ldap.dn, ldap.filter
+import sys
+import Backend, sets
+
+HOST = "ldap://localhost:3890"
+BINDDN = "cn=root,dc=fam"
+PASSWORD = "barn"
+BASE = "dc=fam"
+REF_ATTRIBUTE = "member"
+KEY_ATTRIBUTE = "uid"
+TAG_ATTRIBUTE = "memberOf"
+
+OBJECT_CLASS = "groupOfNames"
+DN_ATTRIBUTE = "cn"
+
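+# The backend presents one virtual OBJECT_CLASS entry per distinct
+# TAG_ATTRIBUTE value found under BASE.  Each virtual entry's REF_ATTRIBUTE
+# values are the KEY_ATTRIBUTE values of the real entries carrying that tag,
+# and modifying REF_ATTRIBUTE on a virtual entry rewrites TAG_ATTRIBUTE on
+# those real entries.
+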
+# hasSubordinates: TRUE
+
+SCOPE_BASE = "0"
+SCOPE_ONE = "1"
+SCOPE_SUB = "2"
+
+class Storage:
+ def __init__(self, url):
+ self.url = url
+ self.ldap = None
+
+ def __connect(self, force = False):
+ if self.ldap:
+ if not force:
+ return
+ self.ldap.unbind()
+ self.ldap = ldap.initialize(self.url)
+ try:
+ self.ldap.simple_bind_s(BINDDN, PASSWORD)
+ except ldap.LDAPError, ex:
+ raise Backend.Error(Backend.OPERATIONS_ERROR,
+                "Couldn't authenticate internally: %s" % ex.args[0]["desc"])
+
+ def read(self, dn, filtstr = "(objectClass=*)", attrs = None, retries = 1):
+ try:
+ self.__connect()
+ results = self.ldap.search_s(dn, ldap.SCOPE_BASE, filtstr, attrs, 0)
+
+ if not results:
+ return None
+ (dn, entry) = results[0]
+ return entry
+
+ except ldap.SERVER_DOWN, ex:
+ if retries <= 0:
+ raise sys.exc_type, sys.exc_value, sys.exc_traceback
+ self.__connect(True)
+ return self.read(dn, filtstr, attrs, retries - 1)
+
+ def search(self, base, filtstr, attrs = [], retries = 1):
+ try:
+ self.__connect()
+ return self.ldap.search_s(base, ldap.SCOPE_SUBTREE, filtstr, attrs, 0)
+
+ except ldap.SERVER_DOWN, ex:
+ if retries <= 0:
+ raise sys.exc_type, sys.exc_value, sys.exc_traceback
+ self.__connect(True)
+ return self.search(base, filtstr, attrs, retries - 1)
+
+ def modify(self, dn, mods, retries = 1):
+ try:
+ self.__connect()
+ self.ldap.modify_s(dn, mods)
+
+ except ldap.SERVER_DOWN, ex:
+ if retries <= 0:
+ raise sys.exc_type, sys.exc_value, sys.exc_traceback
+ self.__connect(True)
+ self.modify(dn, mods, retries - 1)
+
+
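+# Small helper used below to attach the from_* constructors to the Tags class
+# without binding them as instance methods (an old substitute for staticmethod).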
+class Static:
+ def __init__(self, func):
+ self.__call__ = func
+
+class Tags:
+ def __init__(self, database, tags):
+ self.database = database
+ self.tags = tags
+
+ def __refresh(self, force = False):
+ if not force and self.tags is not None:
+ return
+ try:
+ results = self.database.storage.search(BASE, "(%s=*)" % TAG_ATTRIBUTE, [TAG_ATTRIBUTE])
+
+ tags = { }
+ for (dn, entry) in results:
+ for attr in entry.values():
+ for value in attr:
+ tags[value] = DN_ATTRIBUTE
+ self.tags = tags
+
+ except ldap.LDAPError, ex:
+ raise Backend.Error(Backend.OPERATIONS_ERROR,
+ "Couldn't search ldap for keys: %s" % ex.args[0]["desc"])
+
+ def __len__(self):
+ self.__refresh()
+ return len(self.tags)
+ def __getitem__(self, k):
+ self.__refresh()
+ return self.tags[k]
+    def __setitem__(self, k, v):
+ assert False
+ def __delitem__(self, k):
+ assert False
+ def __contains__(self, k):
+ self.__refresh()
+ return k in self.tags
+ def __iter__(self):
+ self.__refresh()
+ return iter(self.tags)
+ def items(self):
+ self.__refresh()
+ return self.tags.items()
+
+
+ def from_dn(dn):
+ try:
+ parsed = ldap.dn.str2dn(dn)
+ except ValueError:
+            raise Backend.Error(Backend.PROTOCOL_ERROR, "invalid dn: %s" % dn)
+ return Tags.from_parsed_dn(parsed)
+ from_dn = Static(from_dn)
+
+ def from_parsed_dn(parsed):
+ tags = { }
+ for (typ, val, num) in parsed[0]:
+ tags[val] = typ
+ return Tags(None, tags)
+ from_parsed_dn = Static(from_parsed_dn)
+
+ def from_database(database):
+ return Tags(database, None)
+ from_database = Static(from_database)
+
+
+def is_parsed_dn_parent (dn, parent):
+    if len(dn) <= len(parent):
+        return False
+    # Go backwards and validate each parent RDN
+    for i in range(-1, -1 - len(parent), -1):
+        if dn[i] != parent[i]:
+            return False
+    return True
+
+class Database(Backend.Database):
+ def __init__(self, suffix):
+ Backend.Database.__init__(self, suffix)
+ self.storage = Storage(HOST)
+ try:
+ self.suffix_dn = ldap.dn.str2dn(self.suffix)
+ except ValueError:
+            raise Backend.Error(Backend.PROTOCOL_ERROR, "invalid suffix dn")
+
+ def __search_tag_keys(self, tags):
+
+ if not len(tags):
+ return []
+
+ # Build up a filter
+ filter = [ldap.filter.filter_format("(%s=%s)", (TAG_ATTRIBUTE, tag)) for tag in tags]
+ if len(filter) > 1:
+ filter = "(&" + "".join(filter) + ")"
+ else:
+ filter = filter[0]
+
+ try:
+ # Search for all those guys
+ results = self.storage.search(BASE, filter, [ KEY_ATTRIBUTE ])
+
+ except ldap.LDAPError, ex:
+ raise Backend.Error(Backend.OPERATIONS_ERROR,
+ "Couldn't search ldap for tags: %s" % ex.args[0]["desc"])
+
+        return [entry[KEY_ATTRIBUTE][0] for (dn, entry) in results if entry.get(KEY_ATTRIBUTE)]
+
+
+ def __search_key_dns(self, key):
+
+ # Build up a filter
+ filter = ldap.filter.filter_format("(%s=%s)", (KEY_ATTRIBUTE, key))
+
+ try:
+ # Do the actual search
+ results = self.storage.search(BASE, filter)
+
+ except ldap.LDAPError, ex:
+ raise Backend.Error(Backend.OPERATIONS_ERROR,
+ "Couldn't search ldap for keys: %s" % ex.args[0]["desc"])
+
+ return [dn for (dn, entry) in results if dn]
+
+
+ def __build_root_entry(self, tags, any_attrs = True):
+ attrs = {
+ "objectClass" : [ "top" ],
+ "hasSubordinates" : [ ]
+ }
+
+ for (typ, val, num) in self.suffix_dn[0]:
+ if not attrs.has_key(typ):
+ attrs[typ] = [ ]
+ attrs[typ].append(val)
+
+ # Note that we don't access 'tags' unless attributes requested
+ if any_attrs:
+            attrs["hasSubordinates"].append(tags and "TRUE" or "FALSE")
+
+ return (self.suffix, attrs)
+
+
+ def __build_pivot_entry(self, tags, any_attrs = True):
+ attrs = {
+ REF_ATTRIBUTE : [ ],
+ "objectClass" : [ OBJECT_CLASS ],
+ "hasSubordinates" : [ "FALSE" ]
+ }
+
+ # Build up a DN, and relevant attrs
+ rdn = []
+ for tag, typ in tags.items():
+ rdn.append((typ, tag, 1))
+ if not attrs.has_key(typ):
+ attrs[typ] = [ ]
+ attrs[typ].append(tag)
+ dn = [ rdn ]
+ dn.extend(self.suffix_dn)
+ dn = ldap.dn.dn2str(dn)
+
+ # Get out all the attributes
+ if any_attrs:
+ for key in self.__search_tag_keys(tags):
+ attrs[REF_ATTRIBUTE].append(key)
+
+ return (dn, attrs)
+
+
+
+ def __limit_results(self, args, results):
+ # TODO: Support sizelimit
+ # TODO: Support a filter
+
+ # Only return the attribute names?
+ if args["attrsonly"] == "1":
+ for (dn, attrs) in results:
+ for attr in attrs:
+ attrs[attr] = [ "" ]
+
+ # Only return these attributes?
+ which = args["attrs"]
+ if which != "all" and which != "*" and which != "+":
+ which = which.split(" ")
+ for (dn, attrs) in results:
+ for attr in attrs.keys():
+ if attr not in which:
+ del attrs[attr]
+
+
+ def search(self, dn, args):
+ results = []
+
+ try:
+ parsed = ldap.dn.str2dn(dn)
+ except:
+ raise Backend.Error(Backend.PROTOCOL_ERROR, "Invalid dn in search: %s" % dn)
+
+ # Arguments sent
+ scope = args["scope"] or SCOPE_BASE
+ any_attrs = len(args["attrs"].strip()) > 0
+
+ # Start at the root
+ if parsed == self.suffix_dn:
+ tags = Tags.from_database(self)
+ if scope == SCOPE_BASE or scope == SCOPE_SUB:
+ results.append(self.__build_root_entry(tags, any_attrs))
+ if scope == SCOPE_ONE or scope == SCOPE_SUB:
+ # Process each tag individually, by default
+ for (tag, typ) in tags.items():
+ results.append(self.__build_pivot_entry({ tag : typ }, any_attrs))
+
+ # Start at a tag
+ elif is_parsed_dn_parent (parsed, self.suffix_dn):
+ tags = Tags.from_parsed_dn(parsed)
+ if scope == SCOPE_BASE or scope == SCOPE_SUB:
+ results.append(self.__build_pivot_entry(tags, any_attrs))
+
+ # We don't have that base
+ else:
+ raise Backend.Error(Backend.NO_SUCH_OBJECT, "DN '%s' does not exist" % dn)
+
+ self.__limit_results(args, results)
+ return results
+
+
+ def __build_key_mods(self, key, tags, op, mods):
+ dns = self.__search_key_dns(key)
+ if not dns:
+ raise Backend.Error(Backend.CONSTRAINT_VIOLATION,
+ "Cannot find an entry for %s '%s'" % (KEY_ATTRIBUTE, key))
+ for dn in dns:
+ if not mods.has_key(dn):
+ mods[dn] = (key, [])
+ for tag in tags:
+ mods[dn][1].append((op, TAG_ATTRIBUTE, tag))
+
+ def modify(self, dn, mods):
+
+ try:
+ parsed = ldap.dn.str2dn(dn)
+ except:
+ raise Backend.Error(Backend.PROTOCOL_ERROR, "Invalid dn in modify: %s" % dn)
+
+ if dn == self.suffix:
+ raise Backend.Error(Backend.INSUFFICIENT_ACCESS,
+ "Cannot modify root dn of pivot area: %s" % dn)
+
+ if not is_parsed_dn_parent (parsed, self.suffix_dn):
+ raise Backend.Error(Backend.NO_SUCH_OBJECT,
+ "DN '%s' does not exist" % dn)
+
+ tags = Tags.from_parsed_dn(parsed)
+
+ add_keys = sets.Set()
+ remove_keys = sets.Set()
+ remove_all = False
+
+ # Parse out all the adds and removes
+ for (op, attr, value) in mods:
+
+ if attr != REF_ATTRIBUTE:
+ raise Backend.Error(Backend.CONSTRAINT_VIOLATION,
+ "Cannot modify '%s' attribute" % attr)
+
+ if op == ldap.MOD_ADD:
+ if value:
+ add_keys.add(value)
+ elif op == ldap.MOD_REPLACE:
+ remove_all = True
+ if value:
+ add_keys.add(value)
+ elif op == ldap.MOD_DELETE:
+ if value:
+ remove_keys.add(value)
+ else:
+ remove_all = True
+ else:
+ continue
+
+ # Remove all of the ref attribute
+ if remove_all:
+ for key in self.__search_tag_keys (tags):
+ remove_keys.add(key)
+
+ # Make them all unique, and non conflicting
+ for key in add_keys.copy():
+ if key in remove_keys:
+ remove_keys.remove(key)
+ add_keys.remove(key)
+
+ # Change them to DNs, and build mods objects for each dn
+ keys_and_mods_by_dn = { }
+ for key in add_keys:
+ self.__build_key_mods(key, tags, ldap.MOD_ADD, keys_and_mods_by_dn)
+ for key in remove_keys:
+ self.__build_key_mods(key, tags, ldap.MOD_DELETE, keys_and_mods_by_dn)
+
+
+ # Now perform the actual actions, combining errors
+ errors = []
+ for (dn, (key, mod)) in keys_and_mods_by_dn.items():
+ try:
+ print dn, mod
+ self.storage.modify(dn, mod)
+ except (ldap.TYPE_OR_VALUE_EXISTS, ldap.NO_SUCH_ATTRIBUTE):
+ continue
+ except ldap.NO_SUCH_OBJECT:
+ errors.append(key)
+ except ldap.LDAPError, ex:
+ raise Backend.Error(Backend.OPERATIONS_ERROR,
+ "Couldn't perform one of the modifications: %s" % ex.args[0]["desc"])
+
+ # Send back errors
+ if errors:
+ raise Backend.Error(Backend.CONSTRAINT_VIOLATION,
+ "Couldn't change %s for %s" % (KEY_ATTRIBUTE, ", ".join(errors)))
+
+
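As a worked example (the group and user names are invented): handing Database.modify above a modification of the virtual entry cn=admins,dc=fam such as

    dn: cn=admins,dc=fam
    changetype: modify
    add: member
    member: alice

is pivoted into a modify of whichever real entry has uid=alice, adding memberOf: admins there. A corresponding "delete: member" removes that memberOf value again, and "replace: member" first strips memberOf: admins from every current member before adding the new values.
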
diff --git a/slapd-pivot.py b/slapd-pivot.py
new file mode 100644
index 0000000..7eb1dab
--- /dev/null
+++ b/slapd-pivot.py
@@ -0,0 +1,18 @@
+#!/usr/bin/env python
+
+import sys
+import Backend, Pivot
+
+def run_server():
+ Backend.debug = True
+ server = Backend.Server("/tmp/pivot-slapd.sock", Pivot.Database)
+
+ try:
+ server.serve_forever()
+ except KeyboardInterrupt:
+ sys.exit(0)
+
+
+if __name__ == '__main__':
+ run_server()
+
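
As a quick smoke test, a minimal client can speak the socket protocol directly. The sketch below is not part of the commit: the function name is made up, the socket path is the one hard-coded in slapd-pivot.py, and it assumes the server is running and the LDAP server configured in Pivot.py is reachable.

import socket

def pivot_search(base="dc=fam", scope="2", attrs="*"):
    # scope "2" is SCOPE_SUB in Pivot.py
    # Connect to the Unix socket that Backend.Server listens on
    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    sock.connect("/tmp/pivot-slapd.sock")
    # One request block: command line, "name: value" lines, blank-line terminator
    request = ("SEARCH\n"
               "msgid: 1\n"
               "suffix: %s\n"
               "base: %s\n"
               "scope: %s\n"
               "attrs: %s\n"
               "\n") % (base, base, scope, attrs)
    sock.sendall(request)
    # The backend sends LDIF entries and a RESULT block, then closes the socket
    chunks = []
    while True:
        data = sock.recv(1024)
        if not data:
            break
        chunks.append(data)
    sock.close()
    return "".join(chunks)

if __name__ == "__main__":
    print pivot_search()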