From ea4a294b84254fb19b2edd24683fa6501970d049 Mon Sep 17 00:00:00 2001 From: Stef Walter Date: Fri, 6 Jun 2008 17:35:32 +0000 Subject: Initial import --- Backend.py | 393 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ Pivot.py | 397 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ slapd-pivot.py | 18 +++ 3 files changed, 808 insertions(+) create mode 100644 Backend.py create mode 100644 Pivot.py create mode 100644 slapd-pivot.py diff --git a/Backend.py b/Backend.py new file mode 100644 index 0000000..6dc5497 --- /dev/null +++ b/Backend.py @@ -0,0 +1,393 @@ +#!/usr/bin/env python + +from __future__ import with_statement +import sys, os, time, re +import threading, mutex +import SocketServer, StringIO +import ldap, ldif + +debug = 0 + +OPERATIONS_ERROR = 0x01 +PROTOCOL_ERROR = 0x02 +TIMELIMIT_EXCEEDED = 0x03 +SIZELIMIT_EXCEEDED = 0x04 +COMPARE_FALSE = 0x05 +COMPARE_TRUE = 0x06 +AUTH_METHOD_NOT_SUPPORTED = 0x07 +STRONG_AUTH_NOT_SUPPORTED = 0x07 +STRONG_AUTH_REQUIRED = 0x08 +STRONGER_AUTH_REQUIRED = 0x08 +PARTIAL_RESULTS = 0x09 +ADMINLIMIT_EXCEEDED = 0x0b +CONFIDENTIALITY_REQUIRED = 0x0d +SASL_BIND_IN_PROGRESS = 0x0e +NO_SUCH_ATTRIBUTE = 0x10 +UNDEFINED_TYPE = 0x11 +INAPPROPRIATE_MATCHING = 0x12 +CONSTRAINT_VIOLATION = 0x13 +TYPE_OR_VALUE_EXISTS = 0x14 +INVALID_SYNTAX = 0x15 +NO_SUCH_OBJECT = 0x20 +ALIAS_PROBLEM = 0x21 +INVALID_DN_SYNTAX = 0x22 +IS_LEAF = 0x23 +ALIAS_DEREF_PROBLEM = 0x24 +X_PROXY_AUTHZ_FAILURE = 0x2F +INAPPROPRIATE_AUTH = 0x30 +INVALID_CREDENTIALS = 0x31 +INSUFFICIENT_ACCESS = 0x32 +BUSY = 0x33 +UNAVAILABLE = 0x34 +UNWILLING_TO_PERFORM = 0x35 +LOOP_DETECT = 0x36 +NAMING_VIOLATION = 0x40 +OBJECT_CLASS_VIOLATION = 0x41 +NOT_ALLOWED_ON_NONLEAF = 0x42 +NOT_ALLOWED_ON_RDN = 0x43 +ALREADY_EXISTS = 0x44 +NO_OBJECT_CLASS_MODS = 0x45 +RESULTS_TOO_LARGE = 0x46 +AFFECTS_MULTIPLE_DSAS = 0x47 +OTHER_ERROR = 0x50 + +def split_argument(line): + parts = line.strip().split(':', 2) + name = parts[0].strip() + value = "" + if len(parts) == 2: + value = parts[1].strip() + return (name, value) + + +class Parser(ldif.LDIFParser): + def __init__(self, input): + ldif.LDIFParser.__init__(self, input) + self.dn = None + self.entry = None + def handle(self, dn, entry): + if self.entry is None: + self.dn = dn + self.entry = entry + + +class Database: + def __init__(self, suffix): + self.suffix = suffix + self.binddn = None + self.remote = None + self.mutex = threading.Lock() + + # Overridable to handle specific command + def add(self, dn, entry): + raise VirtualError, (UNWILLING_TO_PERFORM, "Add not implemented") + def bind(self, dn, args): + raise VirtualError, (UNWILLING_TO_PERFORM, "Bind not implemented") + def compare(self, dn, entry): + raise VirtualError, (UNWILLING_TO_PERFORM, "Compare not implemented") + def delete(self, dn, args): + raise VirtualError, (UNWILLING_TO_PERFORM, "Delete not implemented") + def modify(self, dn, modlist): + raise VirtualError, (UNWILLING_TO_PERFORM, "Modify not implemented") + def modrdn(self, dn, args): + raise VirtualError, (UNWILLING_TO_PERFORM, "ModRDN not implemented") + def search(self, dn, args): + raise VirtualError, (UNWILLING_TO_PERFORM, "Search not implemented") + + # Overridable to handle all processing + def process(self, command, dn, block): + try: + + # This we handle specially + if command == "MODIFY": + return self.process_modify_internal(dn, block) + + # This we handle specially + elif command == "ADD": + return self.process_add_internal(dn, block) + + # All the rest we split up, multiple args go into arrays + args = Arguments() + while True: + line = block.readline() + if len(line) == 0: + break + (name, value) = split_argument(line) + args.add(name, value) + + if command == "BIND": + self.bind(dn, args) + elif command == "COMPARE": + result = self.compare(dn, args) + return (result, "", None) + elif command == "DELETE": + self.delete(dn, args) + elif command == "MODIFY": + self.modify(dn, args) + elif command == "MODRDN": + self.modrdn(dn, args) + elif command == "SEARCH": + return self.process_search_internal(dn, args) + elif command == "UNBIND": + assert False # should have been handled in caller + else: + return UNWILLING_TO_PERFORM, "Unsupported operation %s" % command + + return (0, "", []) + + except Error, ex: + return (ex.code, ex.info, None) + + def process_locked(self, command, dn, block, binddn, remote): + with self.mutex: + self.binddn = binddn + self.remote = remote + result = self.process(command, dn, block) + return result + + def process_add_internal(self, dn, block): + parser = Parser(block) + parser.parse() + self.add(dn, parser.entry) + return (0, "", []) + + + def process_compare_intersal(self, dn, block): + parser = Parser(block) + parser.parse() + result = self.compare(dn, parser.entry) + return (result, "", []) + + + def process_search_internal(self, dn, args): + results = [] + data = self.search(args["base"] or dn, args) + for (dn, entry) in data: + result = StringIO.StringIO() + writer = ldif.LDIFWriter(result) + writer.unparse(dn, entry) + results.append(result.getvalue()) + return (0, "", results) + + + def process_modify_internal(self, dn, block): + + op = None + attr = None + batch = 0 + mods = [] + + while True: + line = block.readline() + if len(line) == 0: + break + line = line.strip() + + # A break between different mods + if line == "-": + + # Latch batch was empty, delete/replace all + if batch == 0: + if op == ldap.MOD_DELETE: + mods.append((op, attr, None)) + elif op == ldap.MOD_REPLACE: + mods.append((op, attr, None)) + + batch = 0 + op = None + attr = None + continue + + # The current line + (name, value) = split_argument(line) + + # Don't have a mod type yet + if op is None: + attr = value + if name == "add": + op = ldap.MOD_ADD + elif name == "replace": + op = ldap.MOD_REPLACE + elif name == "delete": + op = ldap.MOD_DELETE + else: + op = None + + # Have a op, add values + elif name == attr: + assert attr is not None + mods.append((op, attr, value)) + batch += 1 + + print mods + self.modify(dn, mods) + return (0, "", []) + + +class Error(Exception): + """Exception to be returned to server""" + + def __init__(self, code = OPERATIONS_ERROR, info = ""): + self.code = code + self.info = info + + def __str__(self): + return "%d: %s" % (self.code, self.info) + +class Arguments: + def __init__(self): + self.dict = {} + def add(self, name, value): + if self.dict.has_key(name): + if type(self.dict[name]) != type([]): + self.dict[name] = [self.dict[name]] + self.dict[name].append(value) + else: + self.dict[name] = value + + self.dict[name] = value + def __getitem__(self, name): + if self.dict.has_key(name): + return self.dict[name] + return None + def __len__(self): + return len(self.dict) + +class Connection(SocketServer.BaseRequestHandler): + + def __init__(self, request, client_address, server): + server.unique += 1 + self.identifier = server.unique + self.block_regex = re.compile("\r?\n[ \t]*\r?\n") + SocketServer.BaseRequestHandler.__init__(self, request, client_address, server) + + def trace(self, message): + global debug + if debug: + prefix = "%04d " % self.identifier + lines = message.split("\n") + print >> sys.stderr, prefix + lines[0] + print >> sys.stderr, "\n".join([prefix + "*** " + line for line in lines[1:]]) + + def handle(self): + + self.trace("CONNECTED") + + req = self.request + req.setblocking(1) + req.settimeout(None) + + extra = "" + block = None + while True: + data = extra + req.recv(1024) + + # End of connection + if len(data) == 0: + break + + parts = self.block_regex.split(data) + if len(parts) > 1: + block = unicode(parts[0], "utf-8", "strict") + break + extra = parts[0] + + if block: + self.trace("REQUEST\n%s\n" % block) + self.handle_block(req, StringIO.StringIO(block)) + + self.trace("DISCONNECTING") + self.request.close() + + def handle_block(self, req, block): + + line = block.readline() + command = line.strip() + + # Disconnect immediately on certain occasions + if not command: + return False + elif command == "UNBIND": + return False + + suffixes = [] + binddn = None + remote = None + ssf = None + msgid = None + dn = None + + while True: + off = block.tell() + + line = block.readline() + (name, value) = split_argument(line) + + if name == "suffix": + suffixes.append(value) + elif name == "msgid" and msgid is None: + msgid = value + elif name == "dn" and dn is None: + dn = value.lower() + elif name == "binddn" and binddn is None: + binddn = value.lower() + elif name == "peername" and remote is None: + remote = value + elif name == "ssf" and ssf is None: + ssf = value + else: # Return this line and continue + block.seek(off) + break + + code = 0 + info = "" + data = None + + if len(suffixes) == 0: + code = OPERATIONS_ERROR + info = "No suffix specified" + elif len(suffixes) > 1: + code = OPERATIONS_ERROR + info = "Multiple suffixes not supported" + else: + database = self.server.find_database(suffixes[0]) + (code, info, data) = database.process_locked(command, dn, block, binddn, remote) + + if data: + for dat in data: + self.trace("DATA\n%s\n" % dat) + req.sendall(dat.strip("\n") + "\n\n") + + result = "RESULT" + if info: + result += "\ninfo: %s" % info + # BUG: Current OpenLDAP always wants code last + result += "\ncode: %d" % code + self.trace("RESPONSE\n%s" % result.strip("\n")) + req.sendall(result) + + return True + + +class Server(SocketServer.ThreadingMixIn, SocketServer.UnixStreamServer): + daemon_threads = True + allow_reuse_address = True + + def __init__(self, address, DatabaseClass, debug = False): + if (os.path.exists(address)): + os.unlink(address) + SocketServer.UnixStreamServer.__init__(self, address, Connection) + self.DatabaseClass = DatabaseClass + self.__databases = {} + self.unique = 0 + + def find_database(self, suffix): + suffix = suffix.lower() + if self.__databases.has_key(suffix): + database = self.__databases[suffix] + else: + database = self.DatabaseClass(suffix) + self.__databases[suffix] = database + return database + + diff --git a/Pivot.py b/Pivot.py new file mode 100644 index 0000000..28cea92 --- /dev/null +++ b/Pivot.py @@ -0,0 +1,397 @@ +#!/usr/bin/env python + +import ldap, ldap.dn, ldap.filter +import Backend, sets + +HOST = "ldap://localhost:3890" +BINDDN = "cn=root,dc=fam" +PASSWORD = "barn" +BASE = "dc=fam" +REF_ATTRIBUTE = "member" +KEY_ATTRIBUTE = "uid" +TAG_ATTRIBUTE = "memberOf" + +OBJECT_CLASS = "groupOfNames" +DN_ATTRIBUTE = "cn" + +# hasSubordinates: TRUE + +SCOPE_BASE = "0" +SCOPE_ONE = "1" +SCOPE_SUB = "2" + +class Storage: + def __init__(self, url): + self.url = url + self.ldap = None + + def __connect(self, force = False): + if self.ldap: + if not force: + return + self.ldap.unbind() + self.ldap = ldap.initialize(self.url) + try: + self.ldap.simple_bind_s(BINDDN, PASSWORD) + except ldap.LDAPError, ex: + raise Backend.Error(Backend.OPERATIONS_ERROR, + "Couldn't do internal authenticate: %s" % ex.args[0]["desc"]) + + def read(self, dn, filtstr = "(objectClass=*)", attrs = None, retries = 1): + try: + self.__connect() + results = self.ldap.search_s(dn, ldap.SCOPE_BASE, filtstr, attrs, 0) + + if not results: + return None + (dn, entry) = results[0] + return entry + + except ldap.SERVER_DOWN, ex: + if retries <= 0: + raise sys.exc_type, sys.exc_value, sys.exc_traceback + self.__connect(True) + return self.read(dn, filtstr, attrs, retries - 1) + + def search(self, base, filtstr, attrs = [], retries = 1): + try: + self.__connect() + return self.ldap.search_s(base, ldap.SCOPE_SUBTREE, filtstr, attrs, 0) + + except ldap.SERVER_DOWN, ex: + if retries <= 0: + raise sys.exc_type, sys.exc_value, sys.exc_traceback + self.__connect(True) + return self.search(base, filtstr, attrs, retries - 1) + + def modify(self, dn, mods, retries = 1): + try: + self.__connect() + self.ldap.modify_s(dn, mods) + + except ldap.SERVER_DOWN, ex: + if retries <= 0: + raise sys.exc_type, sys.exc_value, sys.exc_traceback + self.__connect(True) + self.modify(dn, mods, retries - 1) + + +class Static: + def __init__(self, func): + self.__call__ = func + +class Tags: + def __init__(self, database, tags): + self.database = database + self.tags = tags + + def __refresh(self, force = False): + if not force and self.tags is not None: + return + try: + results = self.database.storage.search(BASE, "(%s=*)" % TAG_ATTRIBUTE, [TAG_ATTRIBUTE]) + + tags = { } + for (dn, entry) in results: + for attr in entry.values(): + for value in attr: + tags[value] = DN_ATTRIBUTE + self.tags = tags + + except ldap.LDAPError, ex: + raise Backend.Error(Backend.OPERATIONS_ERROR, + "Couldn't search ldap for keys: %s" % ex.args[0]["desc"]) + + def __len__(self): + self.__refresh() + return len(self.tags) + def __getitem__(self, k): + self.__refresh() + return self.tags[k] + def __setitem__(self, k): + assert False + def __delitem__(self, k): + assert False + def __contains__(self, k): + self.__refresh() + return k in self.tags + def __iter__(self): + self.__refresh() + return iter(self.tags) + def items(self): + self.__refresh() + return self.tags.items() + + + def from_dn(dn): + try: + parsed = ldap.dn.str2dn(dn) + except ValueError: + raise Backend.Error(Backend.Error.PROTOCOL_ERROR, "invalid dn: %s" % dn) + return Tags.from_parsed_dn(parsed) + from_dn = Static(from_dn) + + def from_parsed_dn(parsed): + tags = { } + for (typ, val, num) in parsed[0]: + tags[val] = typ + return Tags(None, tags) + from_parsed_dn = Static(from_parsed_dn) + + def from_database(database): + return Tags(database, None) + from_database = Static(from_database) + + +def is_parsed_dn_parent (dn, parent): + if len(dn) <= len(parent): + return False + # Go backwards and validate each parent + for i in range(-1, -1 - len(parent)): + print i, dn[i], parent[i] + if dn[i] != parent[i]: + return False + return True + +class Database(Backend.Database): + def __init__(self, suffix): + Backend.Database.__init__(self, suffix) + self.storage = Storage(HOST) + try: + self.suffix_dn = ldap.dn.str2dn(self.suffix) + except ValueError: + raise Backend.Error(Backend.Error.PROTOCOL_ERROR, "invalid suffix dn") + + def __search_tag_keys(self, tags): + + if not len(tags): + return [] + + # Build up a filter + filter = [ldap.filter.filter_format("(%s=%s)", (TAG_ATTRIBUTE, tag)) for tag in tags] + if len(filter) > 1: + filter = "(&" + "".join(filter) + ")" + else: + filter = filter[0] + + try: + # Search for all those guys + results = self.storage.search(BASE, filter, [ KEY_ATTRIBUTE ]) + + except ldap.LDAPError, ex: + raise Backend.Error(Backend.OPERATIONS_ERROR, + "Couldn't search ldap for tags: %s" % ex.args[0]["desc"]) + + return [entry[KEY_ATTRIBUTE][0] for (dn, entry) in results if entry[KEY_ATTRIBUTE]] + + + def __search_key_dns(self, key): + + # Build up a filter + filter = ldap.filter.filter_format("(%s=%s)", (KEY_ATTRIBUTE, key)) + + try: + # Do the actual search + results = self.storage.search(BASE, filter) + + except ldap.LDAPError, ex: + raise Backend.Error(Backend.OPERATIONS_ERROR, + "Couldn't search ldap for keys: %s" % ex.args[0]["desc"]) + + return [dn for (dn, entry) in results if dn] + + + def __build_root_entry(self, tags, any_attrs = True): + attrs = { + "objectClass" : [ "top" ], + "hasSubordinates" : [ ] + } + + for (typ, val, num) in self.suffix_dn[0]: + if not attrs.has_key(typ): + attrs[typ] = [ ] + attrs[typ].append(val) + + # Note that we don't access 'tags' unless attributes requested + if any_attrs: + attrs["hasSubordinates"].append(tags and "FALSE" or "TRUE") + + return (self.suffix, attrs) + + + def __build_pivot_entry(self, tags, any_attrs = True): + attrs = { + REF_ATTRIBUTE : [ ], + "objectClass" : [ OBJECT_CLASS ], + "hasSubordinates" : [ "FALSE" ] + } + + # Build up a DN, and relevant attrs + rdn = [] + for tag, typ in tags.items(): + rdn.append((typ, tag, 1)) + if not attrs.has_key(typ): + attrs[typ] = [ ] + attrs[typ].append(tag) + dn = [ rdn ] + dn.extend(self.suffix_dn) + dn = ldap.dn.dn2str(dn) + + # Get out all the attributes + if any_attrs: + for key in self.__search_tag_keys(tags): + attrs[REF_ATTRIBUTE].append(key) + + return (dn, attrs) + + + + def __limit_results(self, args, results): + # TODO: Support sizelimit + # TODO: Support a filter + + # Only return the attribute names? + if args["attrsonly"] == "1": + for (dn, attrs) in results: + for attr in attrs: + attrs[attr] = [ "" ] + + # Only return these attributes? + which = args["attrs"] + if which != "all" and which != "*" and which != "+": + which = which.split(" ") + for (dn, attrs) in results: + for attr in attrs.keys(): + if attr not in which: + del attrs[attr] + + + def search(self, dn, args): + results = [] + + try: + parsed = ldap.dn.str2dn(dn) + except: + raise Backend.Error(Backend.PROTOCOL_ERROR, "Invalid dn in search: %s" % dn) + + # Arguments sent + scope = args["scope"] or SCOPE_BASE + any_attrs = len(args["attrs"].strip()) > 0 + + # Start at the root + if parsed == self.suffix_dn: + tags = Tags.from_database(self) + if scope == SCOPE_BASE or scope == SCOPE_SUB: + results.append(self.__build_root_entry(tags, any_attrs)) + if scope == SCOPE_ONE or scope == SCOPE_SUB: + # Process each tag individually, by default + for (tag, typ) in tags.items(): + results.append(self.__build_pivot_entry({ tag : typ }, any_attrs)) + + # Start at a tag + elif is_parsed_dn_parent (parsed, self.suffix_dn): + tags = Tags.from_parsed_dn(parsed) + if scope == SCOPE_BASE or scope == SCOPE_SUB: + results.append(self.__build_pivot_entry(tags, any_attrs)) + + # We don't have that base + else: + raise Backend.Error(Backend.NO_SUCH_OBJECT, "DN '%s' does not exist" % dn) + + self.__limit_results(args, results) + return results + + + def __build_key_mods(self, key, tags, op, mods): + dns = self.__search_key_dns(key) + if not dns: + raise Backend.Error(Backend.CONSTRAINT_VIOLATION, + "Cannot find an entry for %s '%s'" % (KEY_ATTRIBUTE, key)) + for dn in dns: + if not mods.has_key(dn): + mods[dn] = (key, []) + for tag in tags: + mods[dn][1].append((op, TAG_ATTRIBUTE, tag)) + + def modify(self, dn, mods): + + try: + parsed = ldap.dn.str2dn(dn) + except: + raise Backend.Error(Backend.PROTOCOL_ERROR, "Invalid dn in modify: %s" % dn) + + if dn == self.suffix: + raise Backend.Error(Backend.INSUFFICIENT_ACCESS, + "Cannot modify root dn of pivot area: %s" % dn) + + if not is_parsed_dn_parent (parsed, self.suffix_dn): + raise Backend.Error(Backend.NO_SUCH_OBJECT, + "DN '%s' does not exist" % dn) + + tags = Tags.from_parsed_dn(parsed) + + add_keys = sets.Set() + remove_keys = sets.Set() + remove_all = False + + # Parse out all the adds and removes + for (op, attr, value) in mods: + + if attr != REF_ATTRIBUTE: + raise Backend.Error(Backend.CONSTRAINT_VIOLATION, + "Cannot modify '%s' attribute" % attr) + + if op == ldap.MOD_ADD: + if value: + add_keys.add(value) + elif op == ldap.MOD_REPLACE: + remove_all = True + if value: + add_keys.add(value) + elif op == ldap.MOD_DELETE: + if value: + remove_keys.add(value) + else: + remove_all = True + else: + continue + + # Remove all of the ref attribute + if remove_all: + for key in self.__search_tag_keys (tags): + remove_keys.add(key) + + # Make them all unique, and non conflicting + for key in add_keys.copy(): + if key in remove_keys: + remove_keys.remove(key) + add_keys.remove(key) + + # Change them to DNs, and build mods objects for each dn + keys_and_mods_by_dn = { } + for key in add_keys: + self.__build_key_mods(key, tags, ldap.MOD_ADD, keys_and_mods_by_dn) + for key in remove_keys: + self.__build_key_mods(key, tags, ldap.MOD_DELETE, keys_and_mods_by_dn) + + + # Now perform the actual actions, combining errors + errors = [] + for (dn, (key, mod)) in keys_and_mods_by_dn.items(): + try: + print dn, mod + self.storage.modify(dn, mod) + except (ldap.TYPE_OR_VALUE_EXISTS, ldap.NO_SUCH_ATTRIBUTE): + continue + except ldap.NO_SUCH_OBJECT: + errors.append(key) + except ldap.LDAPError, ex: + raise Backend.Error(Backend.OPERATIONS_ERROR, + "Couldn't perform one of the modifications: %s" % ex.args[0]["desc"]) + + # Send back errors + if errors: + raise Backend.Error(Backend.CONSTRAINT_VIOLATION, + "Couldn't change %s for %s" % (KEY_ATTRIBUTE, ", ".join(errors))) + + diff --git a/slapd-pivot.py b/slapd-pivot.py new file mode 100644 index 0000000..7eb1dab --- /dev/null +++ b/slapd-pivot.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +import sys +import Backend, Pivot + +def run_server(): + Backend.debug = True + server = Backend.Server("/tmp/pivot-slapd.sock", Pivot.Database) + + try: + server.serve_forever() + except KeyboardInterrupt: + sys.exit(0) + + +if __name__ == '__main__': + run_server() + -- cgit v1.2.3