diff --git a/.gitignore b/.gitignore index 559ac62..612f68b 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ pushhg pushreg mypoly.py tmp +poly2.py diff --git a/README.rst b/README.rst index b6c6cfc..4c29dbc 100644 --- a/README.rst +++ b/README.rst @@ -228,12 +228,14 @@ Manager Inheritance The current polymorphic models implementation unconditionally inherits all managers from its base models (but only the -polymorphic ones). An example:: +polymorphic base models). + +An example (inheriting from MyModel above):: class MyModel2(MyModel): pass - # Managers inherited from MyModel, delivering MyModel2 objects + # Managers inherited from MyModel, delivering MyModel2 objects (including MyModel2 subclass objects) >>> MyModel2.objects.all() >>> MyModel2.ordered_objects.all() diff --git a/poly/polymorphic.py b/poly/polymorphic.py index ce2a3c6..28e33bf 100644 --- a/poly/polymorphic.py +++ b/poly/polymorphic.py @@ -743,19 +743,56 @@ def _translate_polymorphic_filter_spec(queryset_model, field_path, field_val): return _create_model_filter_Q(field_val, not_instance_of=True) elif not '___' in field_path: return None #no change - + # filter expression contains '___' (i.e. filter for polymorphic field) # => get the model class specified in the filter expression - # TODO: if app not given, just model name => try to find model in any app?? - classname, sep, pure_field_path = field_path.partition('___') - if '__' in classname: appname, sep, classname = classname.partition('__') - else: appname = queryset_model._meta.app_label - model = models.get_model(appname, classname) - assert model, 'model %s (in app %s) not found!' % (modelname, appname) - if not issubclass(model, queryset_model): - e = 'queryset filter error: "' + model.__name__ + '" is not derived from "' + queryset_model.__name__ + '"' - raise AssertionError(e) + newpath = _translate_polymorphic_field_path(queryset_model, field_path) + return (newpath, field_val) + + +def _translate_polymorphic_field_path(queryset_model, field_path): + """ + Translate a field path from keyword argument, as used for + PolymorphicQuerySet.filter()-like functions (and Q objects). + E.g.: ModelC___field3 is translated into modela__modelb__modelc__field3 + Returns: translated path + """ + classname, sep, pure_field_path = field_path.partition('___') + assert sep == '___' + + if '__' in classname: + # the user has app label prepended to class name via __ => use Django's get_model function + appname, sep, classname = classname.partition('__') + model = models.get_model(appname, classname) + assert model, 'model %s (in app %s) not found!' % (model.__name__, appname) + if not issubclass(model, queryset_model): + e = 'queryset filter error: "' + model.__name__ + '" is not derived from "' + queryset_model.__name__ + '"' + raise AssertionError(e) + + else: + # the user has only given us the class name via __ + # => select the model from the sub models of the queryset base model + + # function to collect all sub-models, this could be optimized + def add_all_sub_models(model, result): + if issubclass(model, models.Model) and model != models.Model: + # model name is occurring twice in submodel inheritance tree => Error + if model.__name__ in result and model!=result[model.__name__]: + assert model, 'model name is ambiguous: %s.%s, %s.%s!' % ( + model._meta.app_label, model.__name__, + result[model.__name__]._meta.app_label, result[model.__name__].__name__) + + result[model.__name__] = model + + for b in model.__subclasses__(): + add_all_sub_models(b, result) + + submodels = {} + add_all_sub_models(queryset_model, submodels) + model=submodels.get(classname,None) + assert model, 'model %s not found (not a subclass of %s)!' % (model.__name__, queryset_model.__name__) + # create new field path for expressions, e.g. for baseclass=ModelA, myclass=ModelC # 'modelb__modelc" is returned def _create_base_path(baseclass, myclass): @@ -770,7 +807,8 @@ def _translate_polymorphic_filter_spec(queryset_model, field_path, field_val): basepath = _create_base_path(queryset_model, model) newpath = basepath + '__' if basepath else '' newpath += pure_field_path - return (newpath, field_val) + return newpath + def _create_model_filter_Q(modellist, not_instance_of=False): """ @@ -1037,4 +1075,3 @@ class PolymorphicModel(models.Model): if f != last: out += ', ' return '<' + out + '>' -