python-mylib/mylib/ldap.py
Benjamin Renard 6bbacce38a Code cleaning
2021-05-19 19:19:57 +02:00

430 lines
18 KiB
Python

# -*- coding: utf-8 -*-
""" LDAP server connection helper """
import copy
import datetime
import logging

import dateutil.parser
import dateutil.tz
import ldap
import ldap.dn
from ldap.controls import SimplePagedResultsControl
from ldap.controls.simple import RelaxRulesControl
import ldap.modlist as modlist
import pytz
class LdapServer:
    """ LDAP server connection helper """

    # LDAP server URI (ldap://, ldaps:// or ldapi://)
    uri = None
    # Bind DN (if not set, see connect() for the bind method used)
    dn = None
    # Bind password
    pwd = None
    # True to use LDAP protocol version 2 (version 3 is used otherwise)
    v2 = None
    # python-ldap connection object once connected (0 while not connected)
    con = 0

    def __init__(self, uri, dn=None, pwd=None, v2=None, raiseOnError=False, logger=False):
        """
        :param uri: The LDAP server URI
        :param dn: The bind DN (optional)
        :param pwd: The bind password (optional)
        :param v2: Set to True to use LDAP protocol version 2 (optional, default: version 3)
        :param raiseOnError: If True, errors raise a LdapServerException instead of being logged
        :param logger: The logger to use (optional, default: this module's logger)
        """
        self.uri = uri
        self.dn = dn
        self.pwd = pwd
        self.raiseOnError = raiseOnError
        if v2:
            self.v2 = True
        self.logger = logger if logger else logging.getLogger(__name__)

    def _error(self, error, level=logging.WARNING):
        """
        Handle an error: raise it as LdapServerException if raiseOnError is enabled,
        otherwise log it with the specified level.
        """
        if self.raiseOnError:
            raise LdapServerException(error)
        self.logger.log(level, error)

    def connect(self):
        """
        Start connection to LDAP server (if not already connected)

        If a bind DN is configured, a simple bind is made. Otherwise, with a ldapi:// URI,
        a SASL EXTERNAL bind is made. In other cases, no bind is made.

        :return: True on success (or if already connected), False on error
        """
        if self.con != 0:
            # Already connected
            return True
        try:
            con = ldap.initialize(self.uri)
            if self.v2:
                con.protocol_version = ldap.VERSION2  # pylint: disable=no-member
            else:
                con.protocol_version = ldap.VERSION3  # pylint: disable=no-member
            if self.dn:
                con.simple_bind_s(self.dn, self.pwd)
            elif self.uri.startswith('ldapi://'):
                con.sasl_interactive_bind_s("", ldap.sasl.external())
            self.con = con
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error('LdapServer - Error connecting and binding to LDAP server : %s' % e, logging.CRITICAL)
        return False

    @staticmethod
    def get_scope(scope):
        """
        Map scope parameter to python-ldap value

        :param scope: The scope name: 'base', 'one' or 'sub'
        :raise ValueError: If the scope name is unknown
        """
        if scope == 'base':
            return ldap.SCOPE_BASE  # pylint: disable=no-member
        if scope == 'one':
            return ldap.SCOPE_ONELEVEL  # pylint: disable=no-member
        if scope == 'sub':
            return ldap.SCOPE_SUBTREE  # pylint: disable=no-member
        # ValueError is an Exception subclass: existing "except Exception" handlers still work
        raise ValueError("Unknown LDAP scope '%s'" % scope)

    def search(self, basedn, filterstr=None, attrs=None, sizelimit=0, scope=None):
        """
        Run a search on LDAP server

        :param basedn: The search base DN
        :param filterstr: The search filter (optional, default: '(objectClass=*)')
        :param attrs: The list of attributes to retrieve (optional, default: all)
        :param sizelimit: The maximum number of entries to return (optional, default: 0 = no limit)
        :param scope: The search scope: 'base', 'one' or 'sub' (optional, default: 'sub')
        :return: A dict of entries' attributes indexed by their DN
        """
        res_id = self.con.search(
            basedn,
            self.get_scope(scope if scope else 'sub'),
            filterstr if filterstr else '(objectClass=*)',
            attrs if attrs else []
        )
        ret = {}
        # Stop as soon as sizelimit entries are collected (previously sizelimit + 1 were returned)
        while not sizelimit or len(ret) < sizelimit:
            res_type, res_data = self.con.result(res_id, 0)
            # An empty result data marks the end of the search results
            if not res_data:
                break
            if res_type == ldap.RES_SEARCH_ENTRY:  # pylint: disable=no-member
                ret[res_data[0][0]] = res_data[0][1]
        return ret

    def get_object(self, dn, filterstr=None, attrs=None):
        """
        Retrieve a LDAP object specified by its DN

        :param dn: The object DN
        :param filterstr: The search filter (optional)
        :param attrs: The list of attributes to retrieve (optional, default: all)
        :return: The object attributes, or None if not found
        """
        result = self.search(dn, filterstr=filterstr, scope='base', attrs=attrs)
        return result.get(dn)

    def paged_search(self, basedn, filterstr, attrs, scope='sub', pagesize=500):
        """
        Run a paged search on LDAP server (using the RFC2696 simple paged results control)

        :param basedn: The search base DN
        :param filterstr: The search filter
        :param attrs: The list of attributes to retrieve
        :param scope: The search scope: 'base', 'one' or 'sub' (optional, default: 'sub')
        :param pagesize: The maximum number of objects per page (optional, default: 500)
        :return: A dict of entries' attributes indexed by their DN, or False on error
        """
        assert not self.v2, "Paged search is not available on LDAP version 2"
        # Initialize SimplePagedResultsControl object
        page_control = SimplePagedResultsControl(
            True,
            size=pagesize,
            cookie=''  # Start without cookie
        )
        ret = {}
        pages_count = 0
        self.logger.debug(
            "LdapServer - Paged search with base DN '%s', filter '%s', scope '%s', pagesize=%d and attrs=%s",
            basedn,
            filterstr,
            scope,
            pagesize,
            attrs
        )
        while True:
            pages_count += 1
            self.logger.debug(
                "LdapServer - Paged search: request page %d with a maximum of %d objects (current total count: %d)",
                pages_count,
                pagesize,
                len(ret)
            )
            try:
                res_id = self.con.search_ext(
                    basedn,
                    self.get_scope(scope),
                    filterstr,
                    attrs,
                    serverctrls=[page_control]
                )
            except ldap.LDAPError as e:  # pylint: disable=no-member
                self._error('LdapServer - Error running paged search on LDAP server: %s' % e, logging.CRITICAL)
                return False
            try:
                _, rdata, _, rctrls = self.con.result3(res_id)
            except ldap.LDAPError as e:  # pylint: disable=no-member
                self._error('LdapServer - Error pulling paged search result from LDAP server: %s' % e, logging.CRITICAL)
                return False
            # Detect and catch PagedResultsControl answer from rctrls
            result_page_control = next(
                (
                    rctrl for rctrl in (rctrls or [])
                    if rctrl.controlType == SimplePagedResultsControl.controlType
                ),
                None
            )
            # Without a PagedResultsControl answer, the server does not support paged search
            if not result_page_control:
                self._error('LdapServer - Server ignores RFC2696 control, paged search can not works', logging.CRITICAL)
                return False
            # Store results of this page
            for obj_dn, obj_attrs in rdata:
                ret[obj_dn] = obj_attrs
            # If no cookie returned, we are done
            if not result_page_control.cookie:
                break
            # Otherwise, set cookie for the next search
            page_control.cookie = result_page_control.cookie
        self.logger.debug("LdapServer - Paged search end: %d object(s) retreived in %d page(s) of %d object(s)", len(ret), pages_count, pagesize)
        return ret

    def add_object(self, dn, attrs):
        """
        Add an object in LDAP directory

        :param dn: The DN of the new object
        :param attrs: The object attributes (dict)
        :return: True on success, False otherwise
        """
        ldif = modlist.addModlist(attrs)
        try:
            self.logger.debug("LdapServer - Add %s", dn)
            self.con.add_s(dn, ldif)
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error("LdapServer - Error adding %s : %s" % (dn, e), logging.ERROR)
        return False

    def update_object(self, dn, old, new, ignore_attrs=None, relax=False):
        """
        Update an object in LDAP directory

        :param dn: The DN of the object
        :param old: The current object attributes (dict)
        :param new: The new object attributes (dict)
        :param ignore_attrs: List of attribute names ignored when computing changes (optional)
        :param relax: If True, use the Relax Rules control (not available with LDAP version 2)
        :return: True on success (or if there is nothing to change), False otherwise
        """
        assert not relax or not self.v2, "Relax modification is not available on LDAP version 2"
        ldif = modlist.modifyModlist(
            old, new,
            ignore_attr_types=ignore_attrs if ignore_attrs else []
        )
        if not ldif:
            # Nothing to change
            return True
        try:
            if relax:
                self.con.modify_ext_s(dn, ldif, serverctrls=[RelaxRulesControl()])
            else:
                self.con.modify_s(dn, ldif)
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error("LdapServer - Error updating %s : %s\nOld : %s\nNew : %s" % (dn, e, old, new), logging.ERROR)
        return False

    @staticmethod
    def update_need(old, new, ignore_attrs=None):
        """
        Check if an update is need on a LDAP object based on its old and new attributes values

        :return: True if at least one change is detected, False otherwise
        """
        ldif = modlist.modifyModlist(
            old, new,
            ignore_attr_types=ignore_attrs if ignore_attrs else []
        )
        return bool(ldif)

    @staticmethod
    def get_changes(old, new, ignore_attrs=None):
        """ Retrieve changes (as modlist) on an object based on its old and new attributes values """
        return modlist.modifyModlist(
            old, new,
            ignore_attr_types=ignore_attrs if ignore_attrs else []
        )

    @staticmethod
    def format_changes(old, new, ignore_attrs=None, prefix=None):
        """
        Format changes (modlist) on an object based on its old and new attributes values to display/log it

        :param prefix: String put at the beginning of each line (optional, default: empty string)
        :return: One line per change, joined with newlines
        """
        # Use an empty prefix (and not "None") on every line when no prefix is provided
        prefix = prefix if prefix else ''
        op_labels = {
            ldap.MOD_ADD: 'ADD',  # pylint: disable=no-member
            ldap.MOD_DELETE: 'DELETE',  # pylint: disable=no-member
            ldap.MOD_REPLACE: 'REPLACE',  # pylint: disable=no-member
        }
        msg = []
        for (op, attr, val) in modlist.modifyModlist(old, new, ignore_attr_types=ignore_attrs if ignore_attrs else []):
            op = op_labels.get(op, 'UNKNOWN (=%s)' % op)
            if val is None and op == 'DELETE':
                msg.append('%s - %s %s' % (prefix, op, attr))
            else:
                msg.append('%s - %s %s: %s' % (prefix, op, attr, val))
        return '\n'.join(msg)

    def rename_object(self, dn, new_rdn, new_sup=None, delete_old=True):
        """
        Rename an object in LDAP directory

        :param dn: The current DN of the object
        :param new_rdn: The new RDN of the object. A complete DN could also be provided:
                        in this case, it is split into the new RDN and the new superior DN.
        :param new_sup: The new superior DN (optional, default: unchanged). Must not be
                        provided if new_rdn is a complete DN.
        :param delete_old: Passed as delold to rename_s(): delete the old RDN value(s)
                           (optional, default: True)
        :return: True on success, False otherwise
        """
        try:
            # Parse new_rdn as a DN instead of naively splitting it on commas, to correctly
            # handle escaped commas in attribute values (e.g. "cn=Doe\, John")
            new_dn_parts = ldap.dn.explode_dn(new_rdn)
            # If new_rdn is a complete DN, split new RDN and new superior DN
            if len(new_dn_parts) > 1:
                self.logger.debug(
                    "LdapServer - Rename with a full new DN detected (%s): split new RDN and new superior DN",
                    new_rdn
                )
                assert new_sup is None, "You can't provide a complete DN as new_rdn and also provide new_sup parameter"
                new_sup = ','.join(new_dn_parts[1:])
                new_rdn = new_dn_parts[0]
            self.logger.debug(
                "LdapServer - Rename %s in %s (new superior: %s, delete old: %s)",
                dn,
                new_rdn,
                "same" if new_sup is None else new_sup,
                delete_old
            )
            self.con.rename_s(dn, new_rdn, newsuperior=new_sup, delold=delete_old)
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error(
                "LdapServer - Error renaming %s in %s (new superior: %s, delete old: %s): %s" % (
                    dn,
                    new_rdn,
                    "same" if new_sup is None else new_sup,
                    delete_old,
                    e
                ),
                logging.ERROR
            )
        return False

    def drop_object(self, dn):
        """
        Drop an object in LDAP directory

        :param dn: The DN of the object to delete
        :return: True on success, False otherwise
        """
        try:
            self.logger.debug("LdapServer - Delete %s", dn)
            self.con.delete_s(dn)
            return True
        except ldap.LDAPError as e:  # pylint: disable=no-member
            self._error("LdapServer - Error deleting %s : %s" % (dn, e), logging.ERROR)
        return False

    @staticmethod
    def get_dn(obj):
        """
        Retrieve an object DN from its entry in LDAP search result

        :param obj: A sequence of (dn, attrs) entries (presumably a raw python-ldap
                    search result): the DN of the first one is returned
        """
        return obj[0][0]

    @staticmethod
    def get_attr(obj, attr, all=None, default=None):  # pylint: disable=redefined-builtin
        """
        Retrieve an object attribute value(s) from the object entry in LDAP search result

        If the attribute name is not found as provided, it is looked up case-insensitively.

        :param obj: The object attributes (dict of lists of values)
        :param attr: The attribute name
        :param all: If not None (even if falsy), return the list of all values instead of the first one
        :param default: The value returned if the attribute is not defined (when all is
                        not None, an empty list is returned if default is falsy)
        """
        if attr not in obj:
            for key in obj:
                if key.lower() == attr.lower():
                    attr = key
                    break
        if all is not None:
            if attr in obj:
                return obj[attr]
            return default or []
        if attr in obj:
            return obj[attr][0]
        return default
class LdapServerException(Exception):
    """
    Generic exception raised by LdapServer (when its raiseOnError option is enabled)

    Derives from Exception (and not BaseException) so that generic
    "except Exception" handlers can catch it, as for any other application error.
    """

    def __init__(self, msg):
        super().__init__(msg)
#
# LDAP date string helpers
#
def parse_datetime(value, to_timezone=None, default_timezone=None, naive=None):
    """
    Convert LDAP date string to datetime.datetime object

    :param value: The LDAP date string to convert
    :param to_timezone: If specified, the returned datetime is converted to this specific
                        timezone. Could be a datetime.tzinfo object, a timezone name or
                        'local' for the server local timezone
                        (optional, default: timezone of the LDAP date string)
    :param default_timezone: The timezone used if the LDAP date string does not specify
                             one. Could be a datetime.tzinfo object, a timezone name or
                             'local' for the server local timezone (optional, default: UTC)
    :param naive: If True, return a naive datetime object: the timezone of the LDAP date
                  string (if any) is dropped without conversion and to_timezone is ignored
    :return: A datetime.datetime object
    """
    assert to_timezone is None or isinstance(to_timezone, (datetime.tzinfo, str)), 'to_timezone must be None, a datetime.tzinfo object or a string (not %s)' % type(to_timezone)
    assert default_timezone is None or isinstance(default_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)), 'default_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a datetime.tzinfo object (not %s)' % type(default_timezone)
    date = dateutil.parser.parse(value, dayfirst=False)
    if not date.tzinfo:
        if naive:
            # No timezone in the LDAP date string: return it as is
            return date
        # Resolve the timezone to apply to the LDAP date string
        if not default_timezone:
            default_timezone = pytz.utc
        elif default_timezone == 'local':
            default_timezone = dateutil.tz.tzlocal()
        elif isinstance(default_timezone, str):
            default_timezone = pytz.timezone(default_timezone)
        if isinstance(default_timezone, pytz.tzinfo.DstTzInfo):
            # pytz timezones with DST must be attached using localize(), not replace()
            date = default_timezone.localize(date)
        elif isinstance(default_timezone, datetime.tzinfo):
            date = date.replace(tzinfo=default_timezone)
        else:
            raise Exception("It's not supposed to happen!")
    elif naive:
        # Drop the timezone without any conversion
        return date.replace(tzinfo=None)
    if to_timezone:
        if to_timezone == 'local':
            to_timezone = dateutil.tz.tzlocal()
        elif isinstance(to_timezone, str):
            to_timezone = pytz.timezone(to_timezone)
        return date.astimezone(to_timezone)
    return date
def parse_date(value, to_timezone=None, default_timezone=None, naive=None):
    """
    Convert LDAP date string to datetime.date object

    The string is first parsed as a datetime by parse_datetime() (with the same
    parameters) and the date part of the result is returned.

    :param value: The LDAP date string to convert
    :param to_timezone: If specified, the datetime is converted to this specific timezone
                        before extracting its date. Could be a datetime.tzinfo object, a
                        timezone name or 'local' for the server local timezone
                        (optional, default: timezone of the LDAP date string)
    :param default_timezone: The timezone used if the LDAP date string does not specify
                             one. Could be a datetime.tzinfo object, a timezone name or
                             'local' for the server local timezone (optional, default: UTC)
    :param naive: If True, do not handle timezone conversion from LDAP: the date is
                  extracted as written in the LDAP date string
    :return: A datetime.date object
    """
    return parse_datetime(
        value,
        to_timezone=to_timezone,
        default_timezone=default_timezone,
        naive=naive
    ).date()
def format_datetime(value, from_timezone=None, to_timezone=None, naive=None):
    """
    Convert datetime.datetime object to LDAP date string

    :param value: The datetime.datetime object to convert
    :param from_timezone: The timezone used if the datetime.datetime object is naive (no
                          tzinfo). Could be a datetime.tzinfo object, a timezone name or
                          'local' (optional, default: server local timezone)
    :param to_timezone: The timezone used in LDAP. Could be a datetime.tzinfo object, a
                        timezone name or 'local' (optional, default: UTC)
    :param naive: If True, the datetime is considered as UTC (its tzinfo, if any, is
                  replaced without conversion) and stored as is in LDAP: from_timezone
                  and to_timezone are not applied
    :return: The LDAP date string (YYYYmmddHHMMSS followed by 'Z' for UTC, or by the
             +HHMM/-HHMM offset otherwise)
    """
    assert isinstance(value, datetime.datetime), 'First parameter must be an datetime.datetime object (not %s)' % type(value)
    assert from_timezone is None or isinstance(from_timezone, (datetime.tzinfo, pytz.tzinfo.DstTzInfo, str)), 'from_timezone parameter must be None, a string, a pytz.tzinfo.DstTzInfo or a datetime.tzinfo object (not %s)' % type(from_timezone)
    assert to_timezone is None or isinstance(to_timezone, (datetime.tzinfo, str)), 'to_timezone must be None, a datetime.tzinfo object or a string (not %s)' % type(to_timezone)
    if naive:
        # Consider the datetime as UTC, whatever its tzinfo
        from_value = value.replace(tzinfo=pytz.utc)
    elif not value.tzinfo:
        # Naive datetime: attach the source timezone to it
        if not from_timezone or from_timezone == 'local':
            from_timezone = dateutil.tz.tzlocal()
        elif isinstance(from_timezone, str):
            from_timezone = pytz.timezone(from_timezone)
        if isinstance(from_timezone, pytz.tzinfo.DstTzInfo):
            # pytz timezones with DST must be attached using localize(), not replace()
            from_value = from_timezone.localize(value)
        elif isinstance(from_timezone, datetime.tzinfo):
            from_value = value.replace(tzinfo=from_timezone)
        else:
            raise Exception("It's not supposed to happen!")
    else:
        # datetime objects are immutable: no need to copy it
        from_value = value
    if not to_timezone:
        to_timezone = pytz.utc
    elif to_timezone == 'local':
        to_timezone = dateutil.tz.tzlocal()
    elif isinstance(to_timezone, str):
        to_timezone = pytz.timezone(to_timezone)
    to_value = from_value if naive else from_value.astimezone(to_timezone)
    datestring = to_value.strftime('%Y%m%d%H%M%S%z')
    utc_offset = '+0000'
    if datestring.endswith(utc_offset):
        # Use the 'Z' suffix for UTC datetimes
        datestring = datestring[:-len(utc_offset)] + 'Z'
    return datestring
def format_date(value, from_timezone=None, to_timezone=None, naive=None):
    """
    Convert datetime.date object to LDAP date string

    The date is taken at midnight (00:00:00, without tzinfo) and then converted
    by format_datetime() with the same timezone parameters.

    :param value: The datetime.date object to convert
    :param from_timezone: The timezone in which midnight of the given date is expressed
                          (optional, default: server local timezone)
    :param to_timezone: The timezone used in LDAP (optional, default: UTC)
    :param naive: If True, do not handle timezone conversion before formatting:
                  midnight of the given date is stored as UTC (because LDAP requires
                  a timezone)
    :return: The LDAP date string
    """
    assert isinstance(value, datetime.date), 'First parameter must be an datetime.date object (not %s)' % type(value)
    midnight = datetime.datetime.combine(value, datetime.time.min)
    return format_datetime(
        midnight,
        from_timezone,
        to_timezone,
        naive
    )