From 390be307a7424a60db0170e4c174332b23e42cbb Mon Sep 17 00:00:00 2001 From: Stef Walter Date: Fri, 6 Jun 2008 22:00:04 +0000 Subject: Add cofiguration, daemon, and fix storage bugs. --- Backend.py | 1 - Config.py | 35 +++++++++++++ Pivot.py | 155 ++++++++++++++++++++++++++++++--------------------------- slapd-pivot.py | 137 ++++++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 250 insertions(+), 78 deletions(-) create mode 100644 Config.py diff --git a/Backend.py b/Backend.py index 6dc5497..8585965 100644 --- a/Backend.py +++ b/Backend.py @@ -219,7 +219,6 @@ class Database: mods.append((op, attr, value)) batch += 1 - print mods self.modify(dn, mods) return (0, "", []) diff --git a/Config.py b/Config.py new file mode 100644 index 0000000..47009f2 --- /dev/null +++ b/Config.py @@ -0,0 +1,35 @@ + +import os +import ConfigParser + +__config = None +SECTION = "main" + +class Error(Exception): + """Configuration error""" + def __init__(self, value): + self.value = value + self.str = repr(self.value) + def __str__(self): + return self.str + +def load(filename): + global __config + conf = ConfigParser.ConfigParser() + conf.read(filename) + if not conf.has_section(SECTION): + raise Error, "invalid or missing config file: %s" % filename + __config = conf + +def require(key): + result = option(key) + if not result: + raise Error, "missing conf option '%s' in section '%s'" % (key, SECTION) + return result + +def option(key, default = None): + assert __config is not None, "configuration not loaded" + if not __config.has_option(SECTION, key): + return default + return __config.get(SECTION, key) + diff --git a/Pivot.py b/Pivot.py index 12cdfb7..5f02d92 100644 --- a/Pivot.py +++ b/Pivot.py @@ -2,31 +2,19 @@ import sys, os, sets import ldap, ldap.dn, ldap.filter, ldif -import Backend - -HOST = "ldap://localhost:3890" -ROOTDN = "cn=root,dc=fam" -PASSWORD = "barn" -BASE = "dc=fam" -REF_ATTRIBUTE = "member" -KEY_ATTRIBUTE = "uid" -TAG_ATTRIBUTE = "memberOf" -ACCESS_ATTRIBUTE = "access" -FILENAME = "/tmp/pivot.ldif" - -OBJECT_CLASS = "group" -DN_ATTRIBUTE = "cn" - -# hasSubordinates: TRUE +import Backend, Config SCOPE_BASE = "0" SCOPE_ONE = "1" SCOPE_SUB = "2" class Lookups: - def __init__(self, url): + def __init__(self, url, binddn, password): self.url = url self.ldap = None + self.binddn = binddn + self.password = password + def __connect(self, force = False): if self.ldap: @@ -35,7 +23,7 @@ class Lookups: self.ldap.unbind() self.ldap = ldap.initialize(self.url) try: - self.ldap.simple_bind_s(ROOTDN, PASSWORD) + self.ldap.simple_bind_s(self.binddn, self.password) except ldap.LDAPError, ex: raise Backend.Error(Backend.OPERATIONS_ERROR, "Couldn't do internal authenticate: %s" % ex.args[0]["desc"]) @@ -80,12 +68,14 @@ class Lookups: class Storage: - def __init__(self, filename): + def __init__(self, filename = None): self.filename = filename self.entries = { } self.load() def load(self): + if not self.filename: + return if not os.path.exists(self.filename): return input = open(self.filename, 'r') @@ -97,16 +87,14 @@ class Storage: self.entries[dn] = entry def save(self): + if not self.filename: + return output = open(self.filename, 'w') print >> output, "# Overwritten automatically, do not edit\n" writer = ldif.LDIFWriter(output) for (dn, entry) in self.entries.items(): if (entry): - print - print dn - print repr(entry) - print writer.unparse(dn, entry) output.close() @@ -174,13 +162,15 @@ class Tags: if not force and self.tags is not None: return try: - results = self.database.lookups.search(BASE, "(%s=*)" % TAG_ATTRIBUTE, [TAG_ATTRIBUTE]) + results = self.database.lookups.search(self.database.search_base, + "(%s=*)" % self.database.tag_attribute, + [self.database.tag_attribute]) tags = { } for (dn, entry) in results: for attr in entry.values(): for value in attr: - tags[value] = DN_ATTRIBUTE + tags[value] = self.database.dn_attribute self.tags = tags except ldap.LDAPError, ex: @@ -235,13 +225,10 @@ def parse_dn(dn): raise Backend.Error(Backend.PROTOCOL_ERROR, "Invalid dn: %s" % dn) def is_parsed_dn_parent(dn, parent): - print dn - print parent if len(dn) != len(parent) + 1: return False # Go backwards and validate each parent for i in range(-1, -1 - len(parent)): - print i, dn[i], parent[i] if dn[i] != parent[i]: return False return True @@ -249,9 +236,23 @@ def is_parsed_dn_parent(dn, parent): class Database(Backend.Database): def __init__(self, suffix): Backend.Database.__init__(self, suffix) - self.lookups = Lookups(HOST) + + self.rootdn = Config.require("ldap-root") + self.search_base = Config.require("ldap-base") + self.lookups = Lookups(Config.require("ldap-host"), self.rootdn, + Config.require("ldap-password")) self.suffix_dn = parse_dn(self.suffix) - self.storage = Storage(FILENAME) + + self.dn_attribute = Config.require("rdn-attribute") + self.object_class = Config.require("ref-objectclass") + self.ref_attribute = Config.require("ref-attribute") + self.key_attribute = Config.require("key-attribute") + self.access_attribute = Config.require("access-attribute") + self.tag_attribute = Config.require("tag-attribute") + + filename = Config.option("storage-file") + self.storage = Storage(filename) + def __search_tag_keys(self, tags): @@ -259,7 +260,7 @@ class Database(Backend.Database): return [] # Build up a filter - filter = [ldap.filter.filter_format("(%s=%s)", (TAG_ATTRIBUTE, tag)) for tag in tags] + filter = [ldap.filter.filter_format("(%s=%s)", (self.tag_attribute, tag)) for tag in tags] if len(filter) > 1: filter = "(&" + "".join(filter) + ")" else: @@ -267,23 +268,23 @@ class Database(Backend.Database): try: # Search for all those guys - results = self.lookups.search(BASE, filter, [ KEY_ATTRIBUTE ]) + results = self.lookups.search(self.search_base, filter, [ self.key_attribute ]) except ldap.LDAPError, ex: raise Backend.Error(Backend.OPERATIONS_ERROR, "Couldn't search ldap for tags: %s" % ex.args[0]["desc"]) - return [entry[KEY_ATTRIBUTE][0] for (dn, entry) in results if entry[KEY_ATTRIBUTE]] + return [entry[self.key_attribute][0] for (dn, entry) in results if entry[self.key_attribute]] def __search_key_dns(self, key): # Build up a filter - filter = ldap.filter.filter_format("(%s=%s)", (KEY_ATTRIBUTE, key)) + filter = ldap.filter.filter_format("(%s=%s)", (self.key_attribute, key)) try: # Do the actual search - results = self.lookups.search(BASE, filter) + results = self.lookups.search(self.search_base, filter) except ldap.LDAPError, ex: raise Backend.Error(Backend.OPERATIONS_ERROR, @@ -292,7 +293,7 @@ class Database(Backend.Database): return [dn for (dn, entry) in results if dn] - def __build_root_entry(self, tags, with_attrs = True): + def __build_root_entry(self, tags): attrs = { "objectClass" : [ "top" ], "hasSubordinates" : [ ] @@ -303,17 +304,14 @@ class Database(Backend.Database): attrs[typ] = [ ] attrs[typ].append(val) - # Note that we don't access 'tags' unless attributes requested - if with_attrs: - attrs["hasSubordinates"].append(tags and "TRUE" or "FALSE") - + attrs["hasSubordinates"].append(tags and "TRUE" or "FALSE") return (self.suffix, attrs) - def __build_pivot_entry(self, tags, with_attrs = True): + def __build_pivot_entry(self, tags, keys): attrs = { - REF_ATTRIBUTE : [ ], - "objectClass" : [ OBJECT_CLASS ], + self.ref_attribute : [ ], + "objectClass" : [ self.object_class ], "hasSubordinates" : [ "FALSE" ] } @@ -328,16 +326,16 @@ class Database(Backend.Database): dn.extend(self.suffix_dn) dn = ldap.dn.dn2str(dn) - # Get out all the attributes - if with_attrs: - for key in self.__search_tag_keys(tags): - attrs[REF_ATTRIBUTE].append(key) + for key in keys: + attrs[self.ref_attribute].append(key) return (dn, attrs) - def __build_storage_entry(self, parsed_dn, with_attrs = True): - attrs = { } + def __build_storage_entry(self, parsed_dn, keys): + attrs = { + self.ref_attribute : [ ] + } # Build up DN relevant attrs for (typ, val, num) in parsed_dn[0]: @@ -345,6 +343,9 @@ class Database(Backend.Database): attrs[typ] = [ ] attrs[typ].append(val) + for key in keys: + attrs[self.ref_attribute].append(key) + # All other storage attributes retrieved later if necessary return (ldap.dn.dn2str(parsed_dn), attrs) @@ -404,38 +405,47 @@ class Database(Backend.Database): # Arguments sent scope = args["scope"] or SCOPE_BASE - with_attrs = len(args["attrs"].strip()) > 0 # Start at the root if parsed == self.suffix_dn: tags = Tags.from_database(self) if scope == SCOPE_BASE or scope == SCOPE_SUB: - (dn, entry) = self.__build_root_entry(tags, with_attrs) + (dn, entry) = self.__build_root_entry(tags) results[dn] = entry if scope == SCOPE_ONE or scope == SCOPE_SUB: # Process each tag individually, by default for (tag, typ) in tags.items(): - (child, entry) = self.__build_pivot_entry({ tag : typ }, with_attrs) + ctags = { tag : typ } + keys = self.__search_tag_keys(ctags) + (child, entry) = self.__build_pivot_entry(ctags, keys) results[child] = entry # Process all extra storage items for child in self.storage.list_dns(): if child not in results: - (child, entry) = self.__build_storage_entry(parse_dn(child)) + cparsed = parse_dn(child) + ctags = Tags.from_parsed_dn(cparsed) + keys = self.__search_tag_keys(ctags) + (child, entry) = self.__build_storage_entry(parse_dn(child), keys) results[child] = entry - # Start at a tag - elif is_parsed_dn_parent(parsed, self.suffix_dn): - tags = Tags.from_parsed_dn(parsed) - if scope == SCOPE_BASE or scope == SCOPE_SUB: - (dn, entry) = self.__build_pivot_entry(tags, with_attrs) - results[dn] = entry - # Something in the database elif self.storage.exists(dn): if scope == SCOPE_BASE or scope == SCOPE_SUB: - (dn, entry) = self.__build_storage_entry(parsed, with_attrs) + tags = Tags.from_parsed_dn(parsed) + keys = self.__search_tag_keys(tags) + (dn, entry) = self.__build_storage_entry(parsed, keys) results[dn] = entry + # Start at a tag + elif is_parsed_dn_parent(parsed, self.suffix_dn): + if scope == SCOPE_BASE or scope == SCOPE_SUB: + tags = Tags.from_parsed_dn(parsed) + keys = self.__search_tag_keys(tags) + if keys: + (dn, entry) = self.__build_pivot_entry(tags, keys) + results[dn] = entry + + # We don't have that base else: raise Backend.Error(Backend.NO_SUCH_OBJECT, "DN '%s' does not exist" % dn) @@ -447,18 +457,18 @@ class Database(Backend.Database): dns = self.__search_key_dns(key) if not dns: raise Backend.Error(Backend.CONSTRAINT_VIOLATION, - "Cannot find an entry for %s '%s'" % (KEY_ATTRIBUTE, key)) + "Cannot find an entry for %s '%s'" % (self.key_attribute, key)) for dn in dns: if not mods.has_key(dn): mods[dn] = (key, []) for tag in tags: - mods[dn][1].append((op, TAG_ATTRIBUTE, tag)) + mods[dn][1].append((op, self.tag_attribute, tag)) def __check_write_access(self, dn): - if self.binddn == ROOTDN: + if self.binddn == self.rootdn: return True - return self.storage.has(dn, ACCESS_ATTRIBUTE, self.binddn) + return self.storage.has(dn, self.access_attribute, self.binddn) def add(self, dn, entry): @@ -469,7 +479,7 @@ class Database(Backend.Database): if parsed == self.suffix_dn: raise Backend.Error(Backend.ALREADY_EXISTS, "This entry already exists: %s" % dn) if not is_parsed_dn_parent(parsed, self.suffix_dn): - raise Backend.Error(Backend.NO_SUCH_OBJECT, "Parent of '%s' does not exist or is not valid" % dn) + raise Backend.Error(Backend.NO_SUCH_OBJECT, "Parent of '%s' does not exist or is not a valid place for an entry" % dn) if self.storage.exists(dn): raise Backend.Error(Backend.ALREADY_EXISTS, "This entry already exists: %s" % dn) if len(self.__search_tag_keys(tags)): @@ -484,8 +494,8 @@ class Database(Backend.Database): mods.append((ldap.MOD_ADD, attr, value)) # Add an access attribute for the creator - if self.binddn and not ACCESS_ATTRIBUTE in entry : - mods.append((ldap.MOD_ADD, ACCESS_ATTRIBUTE, self.binddn)) + if self.binddn and not self.access_attribute in entry : + mods.append((ldap.MOD_ADD, self.access_attribute, self.binddn)) # Make the actual changes self.__change(parsed, mods, tags) @@ -507,7 +517,7 @@ class Database(Backend.Database): raise Backend.Error(Backend.INSUFFICIENT_ACCESS, "Access denied to delete entry: %s" % dn) mods = [] - mods.append((ldap.MOD_DELETE, REF_ATTRIBUTE, None)) + mods.append((ldap.MOD_DELETE, self.ref_attribute, None)) # Make the actual changes self.__change(parsed, mods, tags) @@ -546,7 +556,7 @@ class Database(Backend.Database): for (op, attr, value) in mods: # Process access attributes later - if attr != REF_ATTRIBUTE: + if attr != self.ref_attribute: continue if op == ldap.MOD_ADD: @@ -587,7 +597,6 @@ class Database(Backend.Database): errors = [] for (dn, (key, mod)) in keys_and_mods_by_dn.items(): try: - print dn, mod self.lookups.modify(dn, mod) except (ldap.TYPE_OR_VALUE_EXISTS, ldap.NO_SUCH_ATTRIBUTE): continue @@ -600,12 +609,12 @@ class Database(Backend.Database): # Send back errors if errors: raise Backend.Error(Backend.CONSTRAINT_VIOLATION, - "Couldn't change %s for %s" % (KEY_ATTRIBUTE, ", ".join(errors))) + "Couldn't change %s for %s" % (self.key_attribute, ", ".join(errors))) # Process other attributes now dn = ldap.dn.dn2str(parsed) for (op, attr, value) in mods: - if attr == REF_ATTRIBUTE: + if attr == self.ref_attribute: continue if op == ldap.MOD_ADD: self.storage.store(dn, attr, value) diff --git a/slapd-pivot.py b/slapd-pivot.py index 7eb1dab..d657b1b 100644 --- a/slapd-pivot.py +++ b/slapd-pivot.py @@ -1,18 +1,147 @@ #!/usr/bin/env python -import sys -import Backend, Pivot +import sys, os +import getopt, syslog +import pwd, grp + +import Backend, Pivot, Config + +SCRIPT = "slapd-pivot" +USER = None +GROUP = None + +class Log: + def __init__(self): + syslog.openlog('slapd-pivot') + def write(self, string): + string = string.encode("utf-8", "replace").strip("\n\r") + if string: + syslog.syslog(syslog.LOG_WARNING | syslog.LOG_DAEMON, s) + def flush(self): + pass + + +def failure(msg, details = None): + if details: + msg += ": " + details + print >> sys.stderr, "%s: %s" % (SCRIPT, msg) + sys.exit(1) + + +def usage(): + print >> sys.stderr, "usage: %s -f config [-d level] [-u user] [-g group]" % SCRIPT + sys.exit(2) + def run_server(): - Backend.debug = True server = Backend.Server("/tmp/pivot-slapd.sock", Pivot.Database) - try: server.serve_forever() except KeyboardInterrupt: sys.exit(0) +def drop_privileges(): + global GROUP, USER + + if GROUP: + try: + GROUP = int(GROUP) + except ValueError: + try: + GROUP = grp.getgrgid(GROUP)[2] + except KeyError: + failure("invalid group: %s" % GROUP) + os.setegid(GROUP) + + if USER: + try: + USER = int(USER) + except ValueError: + try: + USER = pwd.getpwnam(USER)[2] + except KeyError: + failure("invalid user: %s" % USER) + os.seteuid(USER) + + +def daemon(): + # do the UNIX double-fork magic, see Stevens' "Advanced + # Programming in the UNIX Environment" for details (ISBN 0201563177) + try: + pid = os.fork() + if pid > 0: + # exit first parent + sys.exit(0) + except OSError, e: + failure("couldn't fork to daemon", e.strerror) + + os.chdir(os.path.dirname(sys.argv[0])) + + # decouple from parent environment + os.setsid() + os.umask(0) + + # do second fork + try: + pid = os.fork() + if pid > 0: + # exit from second parent, print eventual PID before + # print "Daemon PID %d" % pid + open(PIDFILE,'w').write("%d"%pid) + sys.exit(0) + except OSError, e: + failure("couldn't fork to daemon", e.strerror) + + os.chdir("/") + sys.stderr = Log() + + if __name__ == '__main__': + + daemonize = True + config = "/usr/local/etc/slapd-pivot.conf" + + try: + opts, args = getopt.getopt(sys.argv[1:], 'd:f:g:u:') + except getopt.GetoptError: + usage() + + for (opt, oarg) in opts: + if opt == "-d": + try: + daemonize = False + level = int(oarg) + if level >= 4: + Backend.debug = True + except: + failure("invalid debug level", oarg) + elif opt == '-f': + config = oarg + + elif opt == '-g': + GROUP = oarg + + elif opt == '-u': + USER = oarg + + # No extra arguments + if args: + usage() + + try: + # Load up our config file + Config.load(config) + except Config.Error, ex: + failure(str(ex)) + + # Change to a user that was specified + drop_privileges() + + # Become a daemon if requested + if daemonize: + daemon() + + # And off we go run_server() -- cgit v1.2.3