Python: Principles of Object Oriented Programming in Python

What is an “object”?

Object-Oriented Programming (OOP) is a programming paradigm that organizes programs around classes and objects. OOP is widely used in data science for building complex and scalable applications. Python is an object-oriented programming language, which means it fully supports OOP (as do Java and C++, among others).

[Figure: a concept]

[Figure: an “object”, aka a triangle]

attributes (properties)

  • 3 edges

  • edgecolor

  • facecolor …

methods (functions)

  • get_edges

  • set_color

  • move_to …

\(\displaystyle{\textrm{concept described by class Triangle} = \underbrace{\textrm{attributes + methods}}_{\textrm{encapsulation}}}\)
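To illustrate encapsulation, the triangle above can be sketched as a class that bundles its attributes and methods. This is a minimal sketch. The attribute and method names follow the lists above. Two details are assumptions, not part of the original: vertices are stored as a list of (x, y) tuples, and move_to translates the triangle so that its first vertex lands on the target point.

```python
class Triangle:
    """A triangle defined by three vertices (hypothetical example)."""

    def __init__(self, vertices, edgecolor="black", facecolor="white"):
        # attributes (properties)
        self.vertices = list(vertices)   # assumption: three (x, y) tuples
        self.edgecolor = edgecolor
        self.facecolor = facecolor

    # methods (functions)
    def get_edges(self):
        """Return the three edges as pairs of vertices."""
        v = self.vertices
        return [(v[0], v[1]), (v[1], v[2]), (v[2], v[0])]

    def set_color(self, edgecolor=None, facecolor=None):
        """Change the edge and/or face color."""
        if edgecolor is not None:
            self.edgecolor = edgecolor
        if facecolor is not None:
            self.facecolor = facecolor

    def move_to(self, x, y):
        """Translate the triangle so that its first vertex is at (x, y)."""
        dx = x - self.vertices[0][0]
        dy = y - self.vertices[0][1]
        self.vertices = [(vx + dx, vy + dy) for vx, vy in self.vertices]


t = Triangle([(0, 0), (1, 0), (0, 1)])
t.set_color(facecolor="orange")
t.move_to(2, 3)
print(t.vertices)           # [(2, 3), (3, 3), (2, 4)]
print(len(t.get_edges()))   # 3
```

The data (vertices, colors) and the operations on it live in one object. The caller never needs to know how the vertices are stored.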

OOP is particularly useful for data science because it organizes code into modular, reusable components, which saves time and effort on large and complex projects. Additionally, many popular data science libraries, such as NumPy and pandas, are built around classes (e.g. numpy.ndarray and pandas.DataFrame).

Important Vocabulary: object vs. class

It is important to understand the difference between a class and an object.

A class is the blueprint or template used to create objects. It is a data type that can contain data (attributes) and functions (methods). An object is an instance of a class: an actual pen, for instance, as opposed to the dictionary definition of a pen. Each object created from the class blueprint holds its own values for the attributes, while the methods defined in the class are shared by all instances.

Important

  • a class is a definition, aka blueprint or template from which objects are created

  • an object is an entity instantiated from a class.

import matplotlib.pyplot as plt
r = plt.plot(range(10))
print(type(r), type(r[0]))
r
<class 'list'> <class 'matplotlib.lines.Line2D'>
[<matplotlib.lines.Line2D at 0x7f28db35cf70>]
[Figure: the line drawn by plt.plot(range(10))]

In the example above, the result of plt.plot is a list (<class 'list'>) that contains one object of type matplotlib.lines.Line2D. That Line2D object is the line drawn on the plot. Another call would create a different Line2D object, with its own data and properties.

In Python, everything is an object

In Python, everything is an object, and every object belongs to a class. To create an object of a certain class, we use the class name followed by parentheses.

s = "some text"
s1 = s.strip()  # remove leading/trailing spaces
su = s1.upper() # make uppercase
print(su)
SOME TEXT

Here, the variable s is a reference to a string, i.e. an object of type str (a Python built-in class). This class also defines common methods, such as strip() and upper(), that we access with the . (dot) operator.
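The same ideas apply to every built-in type. In this short sketch, values are instances of built-in classes, calling a class by name creates a new instance, and methods are reached with the dot operator.

```python
# Built-in values are instances of built-in classes
print(type(42), type(3.5), type("text"), type([1, 2]))
# <class 'int'> <class 'float'> <class 'str'> <class 'list'>
print(isinstance(42, object))      # True: every class derives from object

# Calling a class by name creates a new instance of it
n = int("42")        # an int built from a string
words = list("abc")  # ['a', 'b', 'c']

# Methods are looked up on the object with the dot operator
print((255).bit_length())          # 8
print("some text".split())         # ['some', 'text']
```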

Matplotlib is a good example of a library that would be very hard to write and use without OOP.

import matplotlib.pyplot as plt
# capture all results in r
r = []
r.append(plt.plot(range(10)))
r.append(plt.xlabel('measured'))
r.append(plt.ylabel('calculated'))
r.append(plt.title('Measured vs. Calculated'))
r.append(plt.grid(True))

# update existing objects
ax = plt.gca()       # gca == get current axis
line = ax.lines[0]
line.set_marker('o')
plt.setp(line, color='g');   # set properties

r
[[<matplotlib.lines.Line2D at 0x7f28d924f820>],
 Text(0.5, 0, 'measured'),
 Text(0, 0.5, 'calculated'),
 Text(0.5, 1.0, 'Measured vs. Calculated'),
 None]
[Figure: plot titled “Measured vs. Calculated”, with axis labels “measured” and “calculated”, a grid, and a green line with circle markers]

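The code above makes a simple plot. Each pyplot call returns an object: a list holding one Line2D, then one Text object for each axis label and for the title (plt.grid returns None). We then retrieve the Line2D from the current axes and modify it afterwards, both with its own method set_marker and with plt.setp.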

Python modules are objects too! They are instances of the class module.

import numpy as np
type(np), type(np.random), type(np.random.random)
(module, module, builtin_function_or_method)

The same is true of functions, of compiled source code (code objects), and even of the running environment.
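A short sketch of these claims using only the standard library. The types module exposes the classes of functions, code objects and modules. The function greet is a hypothetical example.

```python
import sys
import types


def greet(name):
    """Return a greeting."""
    return "Hello, " + name


# A function is an instance of a built-in class
print(type(greet))                            # <class 'function'>
print(isinstance(greet, types.FunctionType))  # True

# ... and carries attributes like any other object
print(greet.__name__, greet.__doc__)          # greet Return a greeting.

# Its compiled body is itself an object (a code object)
print(type(greet.__code__))                   # <class 'code'>

# The interpreter's own state is exposed through a module object
print(type(sys))                              # <class 'module'>
```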

How do we define our own classes / objects?

Let’s make a class “Vector”

Class Vector:

  • attrs: x, y, z

  • Methods: addition, scalar product, projection

class Vector:
    """ A 3D vector. """
    def __init__(self, x, y, z):
        """ Initialize a vector.

        :param x: first component of the vector
        :type x: float
        :param y: second component of the vector
        :type y: float
        :param z: third component of the vector
        :type z: float
        """
        self.x, self.y, self.z = x, y, z

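Any time an instance of a class is created, its __init__() method is called. This method is commonly called the constructor. Strictly speaking, it is an initializer, since the object has already been created by __new__() when __init__() runs. If a class does not define __init__(), it inherits the default one from object, which accepts no extra arguments.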

In the example above, the constructor stores the x, y, z values inside the object.

The first parameter of the __init__() method is self, which refers to the object itself. In Python, every instance method receives the instance explicitly as its first parameter, so you can recognize a method by this parameter in its definition. The name self is a convention, not a requirement. Other languages such as C++ and Java name it this and pass it implicitly.
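This short sketch shows what self means in practice. Calling a method on an instance is the same as calling the function stored in the class and passing the instance explicitly as self. The norm2 method is a hypothetical addition used only for this illustration. It is not part of the exercise.

```python
class Vector:
    """A 3D vector (constructor as above, plus one illustrative method)."""
    def __init__(self, x, y, z):
        self.x, self.y, self.z = x, y, z

    def norm2(self):
        # hypothetical helper: squared Euclidean norm of the vector
        return self.x ** 2 + self.y ** 2 + self.z ** 2


v = Vector(1, 2, 2)
# Calling the method on the instance ...
print(v.norm2())          # 9
# ... is the same as calling the function on the class with v as self
print(Vector.norm2(v))    # 9
```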

Exercise: Complete the Vector class by implementing add, dot and cross methods.

  • write proper documentation (docstrings) that states your assumptions.

You can use the following snippet as a starting point:

class Vector:
    def __init__(self, x, y, z):
        """ Constructor """
        self.x, self.y, self.z = x, y, z
    def add(self, other):
        """ add 2 vectors """
        pass
    def dot(self, other):
        """ dot (scalar) product of 2 vectors """
        pass
    def cross(self, other):
        """ cross product of 2 vectors """
        pass
Hide code cell source
from numbers import Number


class Vector:
    """ A 3D vector. """
    def __init__(self, x, y, z):
        """ Initialize a vector.

        :param x: first component of the vector
        :type x: float
        :param y: second component of the vector
        :type y: float
        :param z: third component of the vector
        :type z: float
        """
        self.x, self.y, self.z = x, y, z

    def add(self, other):
        """ Add two vectors or a scalar to a vector.

        :param other: the vector or scalar value to add
        :type other: Vector or Number
        :return: the sum of the two vectors
        :rtype: Vector
        """
        if isinstance(other, Vector):
            return Vector(self.x + other.x, self.y + other.y, self.z + other.z)
        elif isinstance(other, Number):
            # a scalar is added to every component
            return Vector(self.x + other, self.y + other, self.z + other)
        raise TypeError("Cannot add {} to {}".format(type(other), type(self)))

    def dot(self, other):
        """ Dot product of two vectors, or product of a vector by a scalar.

        :param other: the vector or scalar value
        :type other: Vector or Number
        :return: the scalar product if other is a Vector,
                 the scaled vector if other is a Number
        :rtype: Number or Vector
        """
        if isinstance(other, Vector):
            return sum([self.x * other.x, self.y * other.y, self.z * other.z])
        elif isinstance(other, Number):
            return Vector(self.x * other, self.y * other, self.z * other)
        raise TypeError("Cannot multiply {} by {}".format(type(self), type(other)))

    def cross(self, other):
        """ Cross product of two vectors.

        :param other: the second vector
        :type other: Vector
        :return: the vector self x other
        :rtype: Vector
        """
        if isinstance(other, Vector):
            return Vector(self.y * other.z - self.z * other.y,
                          self.z * other.x - self.x * other.z,
                          self.x * other.y - self.y * other.x)
        raise TypeError("Cannot compute the cross product of {} and {}".format(type(self), type(other)))
a = Vector(1,0,0)
b = Vector(0,1,0)
c = a.add(b)
print('  add: ', a.add(b))
print('  dot: ', a.dot(b))
print('cross: ', a.cross(b))
print(f'Vector({c.x}, {c.y}, {c.z})')
  add:  <__main__.Vector object at 0x7f28d9217f40>
  dot:  0
cross:  <__main__.Vector object at 0x7f28d9217f40>
Vector(1, 1, 0)

What’s wrong with print?

  • print(a) calls str(a), which in turn calls a.__str__()

  • we need to define __str__()

Exercise: Define a __str__ method that returns a string representation of your object.

Hide code cell source
from numbers import Number


class Vector:
    """ A 3D vector. """
    def __init__(self, x, y, z):
        """ Initialize a vector.

        :param x: first component of the vector
        :type x: float
        :param y: second component of the vector
        :type y: float
        :param z: third component of the vector
        :type z: float
        """
        self.x, self.y, self.z = x, y, z

    def add(self, other):
        """ Add two vectors or a scalar to a vector.

        :param other: the vector or scalar value to add
        :type other: Vector or Number
        :return: the sum of the two vectors
        :rtype: Vector
        """
        if isinstance(other, Vector):
            return Vector(self.x + other.x, self.y + other.y, self.z + other.z)
        elif isinstance(other, Number):
            # a scalar is added to every component
            return Vector(self.x + other, self.y + other, self.z + other)
        raise TypeError("Cannot add {} to {}".format(type(other), type(self)))

    def dot(self, other):
        """ Dot product of two vectors, or product of a vector by a scalar.

        :param other: the vector or scalar value
        :type other: Vector or Number
        :return: the scalar product if other is a Vector,
                 the scaled vector if other is a Number
        :rtype: Number or Vector
        """
        if isinstance(other, Vector):
            return sum((self.x * other.x, self.y * other.y, self.z * other.z))
        elif isinstance(other, Number):
            return Vector(self.x * other, self.y * other, self.z * other)
        raise TypeError("Cannot multiply {} by {}".format(type(self), type(other)))

    def cross(self, other):
        """ Cross product of two vectors.

        :param other: the second vector
        :type other: Vector
        :return: the vector self x other
        :rtype: Vector
        """
        if isinstance(other, Vector):
            return Vector(self.y * other.z - self.z * other.y,
                          self.z * other.x - self.x * other.z,
                          self.x * other.y - self.y * other.x)
        raise TypeError("Cannot compute the cross product of {} and {}".format(type(self), type(other)))

    def __str__(self):
        """ String representation, used by str() and print(). """
        return "Vector({}, {}, {})".format(self.x, self.y, self.z)
a = Vector(1,0,0)
b = Vector(0,1,0)
c = a.add(b)
print('  add: ', a.add(b))
print('  dot: ', a.dot(b))
print('cross: ', a.cross(b))
  add:  Vector(1, 1, 0)
  dot:  0
cross:  Vector(0, 0, 1)
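print() uses __str__, but other places use __repr__. A list builds its own string by calling repr() on its items, and the notebook uses repr() when it echoes a value. With only __str__ defined, print([a, b]) would still show <__main__.Vector object at ...>. Here is a minimal sketch, with a stripped-down Vector kept only for this illustration. Note that str() falls back to __repr__ when __str__ is not defined.

```python
class Vector:
    """Stripped-down Vector, only for this illustration."""
    def __init__(self, x, y, z):
        self.x, self.y, self.z = x, y, z

    def __repr__(self):
        # used by repr(), by containers such as lists and by the notebook echo;
        # str() and print() fall back to it because __str__ is not defined here
        return "Vector({!r}, {!r}, {!r})".format(self.x, self.y, self.z)


v = Vector(1, 0, 0)
print(v)                       # Vector(1, 0, 0)
print([v, Vector(0, 1, 0)])    # [Vector(1, 0, 0), Vector(0, 1, 0)]
```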

Exercise: Let’s check the identity \(x^2 + y^2 + 2\cdot x \cdot y = (x + y) ^ 2\), where \(x^2\) stands for the dot product \(x \cdot x\). It reduces to Pythagoras’ theorem when \(x \cdot y = 0\).

x = Vector(1, 2, 3)
y = Vector(4, 5, 6)

The following code does the comparison:

Hide code cell source
x.dot(x) + y.dot(y) + 2 * x.dot(y) == x.add(y).dot(x.add(y))
True

As operations pile up, chained method calls such as x.add(y).dot(x.add(y)) quickly become hard to read. OOP offers a solution: a class can define what operators such as + and * do for its instances.

Overloading operators

In Python, operators and several built-in functions are mapped to special (“dunder”) methods. We have already met __str__, which str() and print() call to obtain a string representation of an object.

Mathematical operators are defined by methods such as __add__ (+), __sub__ (-), __mul__ (*), __matmul__ (@), __truediv__ (/), __floordiv__ (//), __pow__ (**), __xor__ (^) and many others.

See the “Data model” chapter of the Python Language Reference.
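Here is a short sketch of how Python dispatches a binary operator, using built-in numbers. When the left operand's method does not support the right operand's type, it returns NotImplemented. Python then tries the reflected method (here __radd__) on the right operand. The same mechanism, through __rmul__, is what lets an expression like 2 * x work when only the right operand is a Vector.

```python
# `a + b` is evaluated as a.__add__(b) ...
print((3).__add__(4))         # 7

# ... and when __add__ does not know the type of b, it returns NotImplemented
print((3).__add__(4.5))       # NotImplemented

# so Python then tries the reflected method b.__radd__(a) on the right operand
print((4.5).__radd__(3))      # 7.5
print(3 + 4.5)                # 7.5
```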

Exercise: complete the Vector class so that it offers a more readable version of \(x^2 + y^2 + 2\cdot x \cdot y = (x + y) ^ 2\), namely x ** 2 + y ** 2 + 2 * x * y == (x + y) ** 2. The class should:

  • manage a string representation

  • handle the common operators + and -

  • implement the dot product with *

  • implement a power method with **

Hide code cell source
from numbers import Number


class Vector:
    """ A 3D vector. """
    def __init__(self, x, y, z):
        """ Initialize a vector.

        :param x: first component of the vector
        :type x: float
        :param y: second component of the vector
        :type y: float
        :param z: third component of the vector
        :type z: float
        """
        self.x, self.y, self.z = x, y, z

    def __add__(self, other):
        """ Add two vectors or a scalar to a vector.

        :param other: the vector or scalar value to add
        :type other: Vector or Number
        :return: the sum of the two vectors
        :rtype: Vector
        """
        if isinstance(other, Vector):
            return Vector(self.x + other.x, self.y + other.y, self.z + other.z)
        elif isinstance(other, Number):
            return Vector(self.x + other, self.y + other, self.z + other)
        raise TypeError("Cannot add {} to {}".format(type(other), type(self)))

    # scalar + vector: addition is commutative, so reuse __add__
    __radd__ = __add__

    def __neg__(self):
        """ Opposite vector; needed by __sub__ when other is a Vector. """
        return Vector(-self.x, -self.y, -self.z)

    def __sub__(self, other):
        """ Subtract a vector or a scalar from a vector.

        :param other: the vector or scalar value to subtract
        :type other: Vector or Number
        :return: the difference of the two vectors
        :rtype: Vector
        """
        return self + (-other)

    def __mul__(self, other):
        """ Dot product of two vectors, or product of a vector by a scalar.

        :param other: the vector or scalar value
        :type other: Vector or Number
        :return: the scalar product if other is a Vector,
                 the scaled vector if other is a Number
        :rtype: Number or Vector
        """
        if isinstance(other, Vector):
            return sum((self.x * other.x, self.y * other.y, self.z * other.z))
        elif isinstance(other, Number):
            return Vector(self.x * other, self.y * other, self.z * other)
        raise TypeError("Cannot multiply {} by {}".format(type(self), type(other)))

    # scalar * vector: reuse __mul__ so that 2 * x works like x * 2
    __rmul__ = __mul__

    def __pow__(self, other):
        """ Square of a vector, defined as its dot product with itself.

        Only the exponent 2 is supported: with * being the dot product,
        higher powers of a vector are not well defined.
        """
        if not isinstance(other, Number):
            raise TypeError("Cannot raise {} to the power {}".format(type(self), type(other)))
        if other != 2:
            raise ValueError("Only the power 2 is supported, got {}".format(other))
        return self * self

    def __str__(self):
        return "Vector({}, {}, {})".format(self.x, self.y, self.z)

x = Vector(1, 2, 3)
y = Vector(4, 5, 6)
print(x ** 2 + y ** 2 + 2 * x * y == (x + y) ** 2)
True

NumPy's ndarray is, for instance, a general n-dimensional counterpart of this Vector class. Note, however, that on arrays * is element-wise, and the dot product is written with @.
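Here is a short sketch with NumPy arrays, for comparison with the Vector class. The same identity holds once the dot product is written with @ instead of *.

```python
import numpy as np

x = np.array([1, 2, 3])
y = np.array([4, 5, 6])

# + and * act element-wise on arrays; the dot product is written with @
print(x + y)        # element-wise sum: 5, 7, 9
print(x * y)        # element-wise product: 4, 10, 18 (not a dot product)
print(x @ y)        # 32

# the identity checked above with the Vector class
lhs = x @ x + y @ y + 2 * x @ y
rhs = (x + y) @ (x + y)
print(lhs == rhs)   # True
```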

Functions are objects too! They implement __call__

Exercise: create a “function” interp from two sequences x and y that keeps the input data in memory and, when called with a sequence xn, returns the interpolated values yn = y(xn).

ℹ️ Two approaches are possible (using a class, or only functions).

💡 Don’t code the interpolation, use np.interp instead.

Hide code cell source
import numpy as np
import matplotlib.pyplot as plt


class Interp:
    """ Simple interpolation class. """
    def __init__(self, x, y):
        """ Store the data points (x, y); x must be increasing. """
        self.xy = (x, y)
    def __call__(self, xn):
        """ Return the values interpolated at xn. """
        x, y = self.xy
        return np.interp(xn, x, y)

def interp_fn(x, y):
    """ An interpolation function. """
    def interpolate(xn):
        """ Return the values interpolated at xn. """
        # x and y are captured from the enclosing call (closure)
        return np.interp(xn, x, y)
    return interpolate

x = np.arange(10)
y = np.hstack([x[: 5] ** 2, 25 - 5 * x[5:]])
interp_obj = Interp(x, y)     # class-based solution
interp = interp_fn(x, y)      # function-based solution
xn = np.linspace(0, 10, 20)
# both solutions return the same values
assert np.allclose(interp_obj(xn), interp(xn))
plt.plot(x, y, 'ko-', lw=2, label='True')
plt.plot(xn, interp(xn), 'o:', lw=3, color='C1', mfc='None', label='Interpolated')
plt.legend();
[Figure: the data points (“True”, black) and the values interpolated at xn (“Interpolated”, orange)]

The implementation above shows both solutions to the problem: an explicit class definition and a function definition.

The latter is a function that returns a function (a closure: the inner function remembers x and y from the enclosing call).
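The heading's claim can be checked directly. Calling a function goes through its __call__ method, and any class that defines __call__ produces callable objects, which is what the Interp class relies on. Here is a minimal sketch with a hypothetical Counter class that, like Interp, keeps state between calls.

```python
def square(v):
    return v * v

# Calling a function goes through its __call__ method
print(square(3), square.__call__(3))   # 9 9
print(callable(square))                # True


class Counter:
    """A callable object that counts its calls (hypothetical example)."""
    def __init__(self):
        self.count = 0

    def __call__(self):
        self.count += 1
        return self.count


counter = Counter()
counter()
counter()
print(counter.count, callable(counter))   # 2 True
```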

Class Attributes vs Object Attributes

In Python, classes and objects store their associated data with attributes. While both class attributes and object attributes serve this purpose, there are important differences between the two.

class Car:
    color = "red"   # class attribute

car1 = Car()
car2 = Car()

print(f'car1.color = {car1.color: <6s}, car2.color = {car2.color: <6s}')
Hide code cell output
car1.color = red   , car2.color = red   
Car.color = "blue"
print(f'car1.color = {car1.color: <6s}, car2.color = {car2.color: <6s}')
Hide code cell output
car1.color = blue  , car2.color = blue  

Class Attributes

A class attribute belongs to the class and is shared among all instances of the class. This means that when a value for a class attribute is changed, that change is reflected in all instances of the class.

In the example above, we have defined a Car class with a class attribute color. Both car1 and car2 are instances of Car, and both read the class attribute color through their class.

Later in the example, the color class attribute is changed to "blue". Notice how both car1 and car2 instances reflect this change.

Object Attributes

An object attribute belongs to an instance of a class. This means that each instance of a class can have its own unique value for an object attribute.

car1.brand = 'bmw'
car2.brand = 'audi'
print(f'car1.brand = {car1.brand: <6s}, car2.brand = {car2.brand: <6s}')
car1.brand = bmw   , car2.brand = audi  

In the example above, both car1 and car2 objects are instances of the class Car. Notice how each instance has its own unique value for brand.

Warning

Avoid creating attributes outside the constructor: it makes code hard to understand and maintain. Declare them in __init__ instead, if only with a placeholder value such as None.

Example:

class Car:
    def __init__(self, color):
        self.color = color
        self.brand = None

A subtle pitfall to keep in mind is that, in Python, an instance attribute takes precedence over a class attribute of the same name.

car2.color = 'black'
print(f'car1.color = {car1.color: <6s}, car2.color = {car2.color: <6s}')
Hide code cell output
car1.color = blue  , car2.color = black 

This example shows that assigning to an attribute of an instance (i.e. object) creates an instance attribute. It overrides the class attribute for that instance only, without affecting the class or the other instances. (Not all languages behave this way.)
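This minimal, self-contained sketch shows the lookup rule at work. Instance attributes are stored in the object's __dict__. Python looks there first and falls back to the class only when the name is missing.

```python
class Car:
    color = "red"            # class attribute

car = Car()
car.color = "black"          # creates an instance attribute that shadows Car.color
print(car.__dict__)          # {'color': 'black'}: instance attributes live here
print(car.color, Car.color)  # black red

del car.color                # remove the instance attribute
print(car.color)             # red: the lookup falls back to the class
```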

What are *args and **kwargs?

*args and **kwargs are special syntaxes in Python that allow a function to accept a variable number of arguments.

*args syntax allows a function to accept any number of positional arguments, which are wrapped up in a tuple. This means that you can call the function with any number of arguments, and they will all be packed into a single parameter of the function.

**kwargs syntax allows a function to accept any number of keyword arguments, which are wrapped up in a dictionary. This means that you can call the function with any number of named arguments, and they will all be packed into a single parameter of the function.

The use of *args and **kwargs makes functions more flexible, generic, and reusable.

Here is an example of how to use *args and **kwargs in Python:

def my_function(*args, **kwargs):
    for arg in args:
        print(arg)
    for key, value in kwargs.items():
        print(key, value)

my_function(1, 2, 3, name='John', age=30)

In the above example, we defined a function my_function that accepts a variable number of arguments - any number of positional arguments using *args and any number of keyword arguments using **kwargs. Inside the function, we looped through all the positional arguments and printed them, and we looped through all the keyword arguments and printed their key-value pairs.

When calling the function, we passed 3 positional arguments and 2 keyword arguments. All the positional arguments were collected in a tuple called args, while all the keyword arguments were collected in a dictionary called kwargs.
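To see what the function actually receives, here is a sketch with a hypothetical helper show that returns its arguments instead of printing them. The same * and ** operators also work in the other direction: they unpack a list or a dict into arguments at the call site.

```python
def show(*args, **kwargs):
    """Return what the function received (hypothetical helper)."""
    return args, kwargs

args, kwargs = show(1, 2, 3, name='John', age=30)
print(type(args), args)       # <class 'tuple'> (1, 2, 3)
print(type(kwargs), kwargs)   # <class 'dict'> {'name': 'John', 'age': 30}

# At the call site, * unpacks a sequence and ** unpacks a dictionary
values = [1, 2, 3]
options = {'name': 'John', 'age': 30}
print(show(*values, **options) == (args, kwargs))   # True
```

This forwarding pattern is common with classes, e.g. a child constructor passing its extra arguments on to its parent.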

Class Inheritance: Establish relationships between classes

Of course, a language feature would not be worthy of the name “class” without supporting inheritance.

Class inheritance is a fundamental feature of object-oriented programming, and it allows us to create a hierarchy of related classes. In Python, a class can inherit attributes and methods from a parent class to create a new child class that has all the features of the parent class, plus some additional features specific to the child class.

To create a child class that inherits from a parent class, we use the syntax class ChildClass(ParentClass):. The child class inherits all the attributes and methods of the parent class and can also override or add its own attributes and methods.

The inheritance relationship between classes is often described as an “is-a” relationship, meaning that the child class is a type of the parent class. For example, if we have a class called Animal, we could create a child class called Cat that inherits from Animal, because a cat is a type of animal.

In Python, a child class can call the methods of its parent class, including the parent's __init__, through the super() function. This allows us to customize the behavior of the child class while still retaining the functionality of the parent class.

Overall, class inheritance is a powerful tool in Python that allows us to create complex class hierarchies with minimal code repetition while still maintaining a high level of code reusability and modularity.

The syntax for a derived class definition looks like this:

class ParentClass:
    pass

class ChildClass(ParentClass):  
    pass
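Here is a minimal sketch of the Animal / Cat example described above. The class contents are illustrative assumptions. Cat calls its parent's constructor through super(), adds its own attribute, and overrides one method. The describe method is inherited unchanged.

```python
class Animal:
    """Parent class (illustrative)."""
    def __init__(self, name):
        self.name = name

    def speak(self):
        return "..."

    def describe(self):
        # calls whichever speak() the actual object provides
        return "{} says {}".format(self.name, self.speak())


class Cat(Animal):
    """Child class: a Cat is an Animal."""
    def __init__(self, name, indoor=True):
        super().__init__(name)      # let the parent set self.name
        self.indoor = indoor        # attribute specific to the child

    def speak(self):                # override the parent method
        return "meow"


tom = Cat("Tom")
print(tom.describe())             # Tom says meow (describe is inherited)
print(isinstance(tom, Animal))    # True: the "is-a" relationship
```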

Exercise: coding Integrators

  1. Write a Python function for \(f(x) = \sin(x)\).

  2. Below is a template of the parent Integrator class. This class holds the common functionality to integrate a function \(f\) over an interval \([a,b]\) split into \(n\) subintervals. However, the integration method itself is not implemented.

class Integrator:
    def __init__(self, a, b, n):
        self.a = a    # lower bound
        self.b = b    # upper bound
        self.n = n    # number of subintervals

    def integrate(self, f):
        # to be overridden by the derived classes
        raise NotImplementedError(
            "Integration not implemented in " + self.__class__.__name__)

  3. Create two derived classes, TrapezoidalRule and MidpointRule, where the integration method is the trapezoidal rule and the midpoint rule, respectively. This should be done by only overriding the integrate method.

  4. Test your classes.

A Solution

Hide code cell source
class TrapezoidalRule(Integrator):
    def integrate(self, f):
        h = (self.b - self.a) / self.n
        summation = sum([f(self.a + i*h) for i in range(1, self.n)])
        integral = h / 2 * (f(self.a) + f(self.b) + 2 * summation)
        return integral

class MidpointRule(Integrator):
    def integrate(self, f):
        h = (self.b - self.a) / self.n
        summation = sum([f(self.a + (i+0.5)*h) for i in range(self.n)])
        integral = h * summation
        return integral
a, b = 0, 2
n = 100
trapezoid = TrapezoidalRule(a, b, n)
midpoint = MidpointRule(a, b, n)

print("Trapezoidal rule approximation:", trapezoid.integrate(np.sin))
print("Midpoint rule approximation:", midpoint.integrate(np.sin))
print("True integration: ", -np.cos(b) + np.cos(a))
Trapezoidal rule approximation: 1.4160996313378889
Midpoint rule approximation: 1.416170439269783
True integration:  1.4161468365471424
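The parent class above signals a missing integrate method by raising NotImplementedError when the method is called. An alternative, not the approach used in this exercise, is the standard-library abc module. Marking integrate with @abstractmethod makes Python refuse to instantiate the parent class at all. Here is a minimal sketch; note that it redefines Integrator and MidpointRule.

```python
from abc import ABC, abstractmethod
import math


class Integrator(ABC):
    """Abstract parent: subclasses must provide integrate(f)."""
    def __init__(self, a, b, n):
        self.a, self.b, self.n = a, b, n

    @abstractmethod
    def integrate(self, f):
        """Approximate the integral of f over [a, b] with n subintervals."""


class MidpointRule(Integrator):
    def integrate(self, f):
        h = (self.b - self.a) / self.n
        return h * sum(f(self.a + (i + 0.5) * h) for i in range(self.n))


try:
    Integrator(0, 2, 100)      # the abstract class cannot be instantiated
except TypeError as err:
    print("TypeError:", err)

print(MidpointRule(0, 2, 100).integrate(math.sin))   # close to 1 - cos(2)
```

The error now appears as soon as someone tries to use the incomplete class, rather than later, when integrate is first called.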

Tip

TODO: add about super()

When to use classes or not?

When it comes to deciding whether to write a class in Python, it depends on the specific needs of your project. In general, classes can be useful when you need to organize and structure your code in a way that makes it more modular, reusable, and maintainable.

When to use classes

  • When working on larger, more complex projects: As your project grows, it may become more difficult to manage all the functions and variables in a single script. Using classes can help organize the code and make it easier to maintain and modify.

  • When you need to define a custom data structure: Classes are a useful tool for creating custom data structures that can help you better manage and analyze your data.

  • When you need to create multiple instances of an object with shared attributes and functions: Classes can be used to define objects with shared attributes and functions that can be reused across different parts of your project.

  • When you need to define a common API: classes let similar objects expose the same methods and properties.

Evidence for writing a class

  • the same set of parameters is passed over and over to many functions. These functions can become methods of a class that stores those parameters as attributes.

  • many copies of code with minor differences (e.g. the same function applied to arguments of different types). Often these can be replaced by classes that share a common interface, each with a method that handles its own type.

  • code that attempts to provide a standard API, e.g. the fit and predict methods shared by scikit-learn estimators.

When not to use classes

  • When working on small projects with simple requirements: If your project is small and only requires a few functions or variables, it may not be necessary to use classes.

  • When working with simple data structures, or with structures already covered by built-in classes: if your project only requires simple data structures such as lists or dictionaries, classes may add unnecessary complexity to your code.

  • When you don’t have experience working with classes: If you’re not familiar with working with classes, it may be more efficient to stick with simple functions and data structures until you have a better understanding of object-oriented programming.

Evidence of overdoing classes

Some examples of overdoing classes are:

  • a class with no content or only an __init__. This kind of class is better expressed as a dictionary, a NamedTuple, or a dataclass.

  • a class with an __init__ and a single method. This kind of class should be a function.

  • a class with many memory-expensive attributes. Large all-in-one containers are rarely a good idea; it is often better to define smaller objects, even if this creates more variables to track.

  • a class with all the functions of your project. A do-it-all class is not recommended and often amounts to a series of functions.
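Here is a minimal sketch of the first two refactorings, with hypothetical names (PointClass, Point, Scaler, scale). A class that only stores values becomes a typing.NamedTuple, which comes with a readable repr and equality for free. A class with one method becomes a plain function.

```python
from typing import NamedTuple


# A class with only an __init__ storing values ...
class PointClass:
    def __init__(self, x, y):
        self.x = x
        self.y = y


# ... can be a NamedTuple (immutable, with repr and equality for free)
class Point(NamedTuple):
    x: float
    y: float


# A class with an __init__ and a single method ...
class Scaler:
    def __init__(self, factor):
        self.factor = factor

    def apply(self, value):
        return self.factor * value


# ... is usually clearer as a plain function
def scale(value, factor):
    return factor * value


p = Point(1.0, 2.0)
print(p)                                   # Point(x=1.0, y=2.0)
print(Scaler(3).apply(2) == scale(2, 3))   # True
```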

Tip

  1. Classes are useful tools.

  2. Simple > complex > complicated.

  3. Practicality beats purity.

  4. Do not reinvent the wheel every time. When the tools you already have work, use and adapt them.

  5. Readability counts!

From the Zen of Python

Jack Diederich describes cases of code over-engineered with classes in his talk “Stop Writing Classes” (30 min video).

Fig. 6 Extracts from the Zen of Python. Credit: Tim Peters.

Exercise: stellar mass function

This section is an exercise to complete a class that represents the Salpeter initial mass function (IMF).

Mathematical background

An IMF is a probability density function: \(\displaystyle{p(m | IMF) = \frac{1}{Z} \frac{dN}{dm}},\)

whose integral is unity by definition, which gives \(\displaystyle{Z = \int_{m_{min}}^{m_{max}} \frac{dN}{dm} dm}.\)

The function is continuous.

💡 The first task of the constructor is to build a continuous function from the broken power-law definition and to normalize it properly.

⚠️ When the indices of the different power laws are given in units of \(dN/dm\), \(-2.35\) corresponds to a Salpeter IMF. You may also find \(-1.35\), which corresponds to an index defined in terms of \(dN/d\log(m)\).

From this, the fractional number of stars within an interval \([m_1, m_2]\) is the integral: \(\displaystyle{n_\star = \frac{1}{Z} \int_{m_{1}}^{m_{2}} \frac{dN}{dm} dm},\)

and the corresponding mass is given by the first moment of the distribution, \(\displaystyle{m_\star = \frac{1}{Z} \int_{m_{1}}^{m_{2}} m \frac{dN}{dm} dm}.\)

The average mass over this interval is then \(\displaystyle{\overline{m}_\star = \frac{m_\star}{n_\star}}\)
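Since \(dN/dm\) is a power law, every integral above has a closed form. The following derivation is a sketch of what the class needs to compute. It writes \(a\) for the slope of \(dN/dm\) and \(p\) for the order of the moment. With \(p = 0\) it gives the integral in \(n_\star\), and with \(p = 1\) the integral in \(m_\star\), both before dividing by \(Z\).

\begin{aligned}
\frac{dN}{dm} &= m^{a}, \qquad a = -2.35,\\
\int_{m_1}^{m_2} m^{p}\,\frac{dN}{dm}\,dm &= \frac{m_2^{\,b} - m_1^{\,b}}{b}, \qquad b = a + 1 + p \neq 0,\\
\int_{m_1}^{m_2} m^{p}\,\frac{dN}{dm}\,dm &= \ln\left(\frac{m_2}{m_1}\right), \qquad b = 0,\\
Z = \int_{0.1}^{120} m^{-2.35}\,dm &= \frac{120^{-1.35} - 0.1^{-1.35}}{-1.35} \approx 16.58.
\end{aligned}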

Complete the Salpeter Class

class Salpeter:
    def __init__(self, range = [0.1, 120.]):
        self.intervals = range
        self.slope = -2.35

    def get_mstar(self, xmin, xmax):
        """Get the enclosed mass over a given mass range. """
        pass

    def get_nstar(self, xmin, xmax):
        """Get the enclosed dN over a given mass range """
        pass

    def get_avg_mstar(self, xmin, xmax):
        """ get the avg mass over a given range """
        pass

    def get_value(self, x):
        """ returns the value of the normalized IMF at a given mass m """
        pass
    
    def __call__(self, x):
        return self.get_value(x)

💡 Don’t hesitate to define other methods to avoid code duplication or improve readability.

For unit tests, you can use the following values (assuming range = [0.1, 120.]):

  • norm = 16.582

  • average mass = 0.354 \(M_\odot\)

Hide code cell source
import numpy as np

class Salpeter:
    """ Salpeter IMF definition """
    def __init__(self, range = [0.1, 120.]):
        self.intervals = np.array(range)
        self.slope = -2.35
        self.name = "Salpeter"

        # good practice to set which variables will be used
        # they do not have to take a value here.
        self._coefficients = None
        self._norm = None
        self._mass_norm = None
        self._avg_mass = None

        self._setup()

    def _setup(self):
        """ precalculate internal values

        It's a good practice for clarity to have a separate method from the constructor.
        """
        # the first step is to build the functional form of the IMF:
        # i.e., make a continuous function and compute the normalization
        n_intervals = len(self.intervals) - 1
        self._coefficients = np.ones(n_intervals)
        self._norm = 1.

        # normalize
        # depends on the adopted definition of the IMF indexes:
        # either dN/dM or dN/dlog(M). In this example we consider that indexes
        # are given in units of dN/dM.
        xmin, xmax = np.min(self.intervals), np.max(self.intervals)
        self._norm = self.get_nstar(xmin, xmax)
        self._mass_norm = self.get_mstar(xmin, xmax)
        # Compute the average mass
        self._avg_mass = self.get_avg_mstar(xmin, xmax)

    def _moment(self, xmin, xmax, order=0):
        """ compute the moment of the IMF: m**order * dn / dm """
        beta = self.slope + 1 + order
        if abs(beta) < 1e-30:
            # m**-1 integrates to a logarithm; normalize like the general case
            return np.log(xmax / xmin) / self._norm

        return (xmax ** (beta) - xmin ** (beta)) / (beta) / self._norm

    def get_nstar(self, xmin, xmax):
        """ Compute the number of stars """
        return self._moment(xmin, xmax, 0)

    def get_mstar(self, xmin, xmax):
        """ Compute the mass of stars """
        return self._moment(xmin, xmax, 1)

    def get_avg_mstar(self, xmin, xmax):
        """ Compute the average mass """
        return (self.get_mstar(xmin, xmax) / self.get_nstar(xmin, xmax))

    def get_value(self, x):
        """ returns the value of the normalized IMF at a given mass m:
        """
        return x ** self.slope / self._norm

    def __call__(self, x):
        return self.get_value(x)
s = Salpeter()

print('s._norm ~ 16.582 |', np.allclose(s._norm, 16.582, atol=1e-4))
print('s.get_nstar(0.1, 120) = 1 |', np.allclose(s.get_nstar(0.1, 120), 1., atol=1e-4))
print('s.get_avg_mstar(0.1, 120) ~ 0.354 |', np.allclose(s.get_avg_mstar(0.1, 120), 0.354, atol=1e-3))
s._norm ~ 16.582 | True
s.get_nstar(0.1, 120) = 1 | True
s.get_avg_mstar(0.1, 120) ~ 0.354 | True

What do we need to change to define a Kroupa IMF, which is a piecewise power law?

\begin{aligned}
\forall m_i \in I_k,\ dN/d\log(m_i) &= \frac{1}{Z}\, m_i^{\alpha_k},\\
{\rm intervals\ } I_n &\in [0.01, 0.08, 0.5, 1., 150.],\\
\alpha_k &\in [0.7, -0.3, -1.3, -1.3]
\end{aligned}

This requires rewriting the code to handle a sequence of slopes and to generalize the integrals to piecewise power laws.
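Here is a sketch of that generalization. It assumes boundaries \(x_0 < x_1 < \dots < x_K\) and a slope \(a_k\) of \(dN/dm\) on each interval \(I_k = [x_k, x_{k+1}]\). Since \(dN/d\log(m) = m\, dN/dm\) up to a constant factor, \(a_k = \alpha_k - 1\); for the Kroupa values above this gives \([-0.3, -1.3, -2.3, -2.3]\). Continuity at each inner boundary fixes the coefficients \(c_k\). Each moment of order \(p\) then becomes a sum of single power-law integrals.

\begin{aligned}
\frac{dN}{dm} &= c_k\, m^{a_k}, \qquad m \in I_k = [x_k, x_{k+1}],\\
c_{k-1}\, x_k^{a_{k-1}} &= c_k\, x_k^{a_k} \;\Rightarrow\; c_k = c_{k-1}\, x_k^{a_{k-1} - a_k}, \qquad c_0 = 1,\\
\int_{x_0}^{x_K} m^{p}\,\frac{dN}{dm}\,dm &= \sum_{k=0}^{K-1} c_k\, \frac{x_{k+1}^{\,b_k} - x_k^{\,b_k}}{b_k}, \qquad b_k = a_k + 1 + p \neq 0.
\end{aligned}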

  • Generalize the previous Salpeter class to an IMF class and define Salpeter and Kroupa to inherit from it.

import numpy as np

class IMF:
    """ IMF definition

    default corresponds to Salpeter
    """
    def __init__(self,
                 slopes=[-2.35],
                 intervals = [0.1, 120.],
                 name="IMF"):
        self.intervals = np.array(intervals)
        self.slopes = np.array(slopes)
        self.name = name

        # good practice: declare every attribute the instance will use,
        # even if it only receives its value later (in _setup)
        self._n_intervals = len(self.intervals) - 1
        self._coefficients = None
        self._norm = None
        self._mass_norm = None
        self._avg_mass  = None

        self._setup()

    def _setup(self):
        """ precalculate internal values """
        self._coefficients = np.ones(self._n_intervals)
        self._norm = 1.
        self._mass_norm = 1.

        # enforce continuity at each inner edge m = intervals[i]:
        # c[i] * m**slopes[i] == c[i-1] * m**slopes[i-1]
        for i in range(1, self._n_intervals):
            self._coefficients[i] = (self._coefficients[i - 1]
                                     * self.intervals[i] ** (self.slopes[i - 1] - self.slopes[i]))

        xmin, xmax = np.min(self.intervals), np.max(self.intervals)
        # _norm is still 1 here, so this returns the raw integral Z
        self._norm = self.get_nstar(xmin, xmax)
        # from here on, moments are normalized by Z
        self._mass_norm = self.get_mstar(xmin, xmax)
        # Compute the average mass
        self._avg_mass = self.get_avg_mstar(xmin, xmax)

    def _moment(self, xmin, xmax, order=0):
        """ integrate the distribution m^order * dn / dm between xmin and xmax """
        x = self.intervals
        a = self.slopes
        b = a + 1 + order
        c = self._coefficients

        # analytical integration of a power law
        val = 0.
        for xlo, xhi, bi, ci in zip(x[:-1], x[1:], b, c):
            if (xmin < xhi) and (xmax > xlo):
                # overlap of [xmin, xmax] with the segment [xlo, xhi]
                x0 = max(xmin, xlo)
                x1 = min(xmax, xhi)
                # if the exponent b = slope + 1 + order vanishes, the integral is a log
                if abs(bi) < 1e-30:
                    S = ci * np.log(x1 / x0)
                else:
                    S = ci / bi * ( (x1) ** (bi) - (x0) ** (bi) )
                val += S
        return val  / self._norm

    def get_nstar(self, xmin, xmax):
        """ Compute the number of stars """
        return self._moment(xmin, xmax, 0)

    def get_mstar(self, xmin, xmax):
        """ Compute the total stellar mass """
        return self._moment(xmin, xmax, 1)

    def get_avg_mstar(self, xmin, xmax):
        """ Compute the average mass """
        return (self.get_mstar(xmin, xmax) / self.get_nstar(xmin, xmax))

    def get_value(self, x):
        """ evaluate the normalized IMF at mass(es) x (NaN outside the intervals) """
        x_ = np.atleast_1d(np.asarray(x, dtype=float))
        # index of the segment each mass falls in (0 or len(intervals) means outside)
        where = np.searchsorted(self.intervals, x_)
        y_ = np.full_like(x_, np.nan)
        ind = (where > 0) & (where < len(self.intervals))
        k = where[ind] - 1
        y_[ind] = self._coefficients[k] * x_[ind] ** self.slopes[k]
        return y_ / self._norm

    def __call__(self, x):
        return self.get_value(x)
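
`get_value` relies on `np.searchsorted` to find the power-law segment that each mass falls in. The standalone sketch below reproduces that lookup. The edges are the Kroupa 2001 ones, and the test masses are arbitrary:

```python
import numpy as np

edges = np.array([0.01, 0.08, 0.5, 1.0, 150.0])
masses = np.array([0.005, 0.05, 0.3, 2.0, 200.0])

# index i such that edges[i-1] < m <= edges[i] (default side='left')
where = np.searchsorted(edges, masses)
print(where)    # [0 1 2 4 5]

# 0 means below the first edge, len(edges) means above the last one
inside = (where > 0) & (where < len(edges))
print(inside)   # [False  True  True  True False]

# 0-based segment index for the masses inside the range
segment = where[inside] - 1
print(segment)  # [0 1 3]
```

With `side='left'`, a mass exactly equal to `edges[0]` also returns 0, so it is treated as outside the range.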
class Salpeter(IMF):
    """ Salpeter IMF """
    def __init__(self):
        super().__init__(slopes=[-2.35],
                         intervals = [0.1, 120.],
                         name="Salpeter")

class Kroupa2001(IMF):
    """ Kroupa IMF """
    def __init__(self):
        super().__init__(slopes=[-0.3, -1.3, -2.3, -2.3],
                         intervals = [0.01, 0.08, 0.5, 1., 150.],
                         name="Kroupa 2001")

class Kennicutt(IMF):
    """ Kennicutt (1983) doi:10.1086/161261"""
    def __init__(self):
        super().__init__(slopes=[-1.4, -2.5],
                         intervals = [0.1, 1., 120.],
                         name="Kennicutt 1983")
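
Each subclass above only fixes the constructor arguments and passes them to `IMF.__init__` through `super()`. Every method is inherited from `IMF`. The example below shows the same pattern in stripped-down form. It is hypothetical: the `Shape`, `Triangle` and `Square` names echo the triangle from the introduction and are not part of the IMF code:

```python
class Shape:
    """Hypothetical base class holding shared attributes and methods."""

    def __init__(self, n_edges, name="shape"):
        self.n_edges = n_edges
        self.name = name

    def describe(self):
        return f"{self.name} with {self.n_edges} edges"


class Triangle(Shape):
    def __init__(self):
        # delegate to the parent constructor with fixed parameters
        super().__init__(n_edges=3, name="triangle")


class Square(Shape):
    def __init__(self):
        super().__init__(n_edges=4, name="square")


t = Triangle()
print(t.describe())          # method inherited from Shape
print(isinstance(t, Shape))  # a Triangle is also a Shape
print(Triangle.__mro__)      # lookup order: Triangle, then Shape, then object
```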
s = Salpeter()
print('s._norm ~ 16.582 |', np.allclose(s._norm, 16.582, atol=1e-4))
print('s.get_nstar(0.1, 120) = 1 |', np.allclose(s.get_nstar(0.1, 120), 1., atol=1e-4))
print('s.get_avg_mstar(0.1, 120) ~ 0.354 |', np.allclose(s.get_avg_mstar(0.1, 120), 0.354, atol=1e-3))
s._norm ~ 16.582 | True
s.get_nstar(0.1, 120) = 1 | True
s.get_avg_mstar(0.1, 120) ~ 0.354 | True
s = Kroupa2001()
print('s._norm ~ 0.5033 |', np.allclose(s._norm, 0.5033, atol=1e-4))
print('s.get_nstar(0.01, 150) = 1 |', np.allclose(s.get_nstar(0.01, 150.), 1., atol=1e-4))
print('s.get_nstar(0.1, 120) ~ 0.5553 |', np.allclose(s.get_nstar(0.1, 120.), 0.5553, atol=1e-4))
print('s.get_avg_mstar(0.01, 150) ~ 0.3838 |', np.allclose(s.get_avg_mstar(0.01, 150),  0.3838, atol=1e-4))
s._norm ~ 0.5033 | True
s.get_nstar(0.01, 150) = 1 | True
s.get_nstar(0.1, 120) ~ 0.5553 | True
s.get_avg_mstar(0.01, 150) ~ 0.3838 | True
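
The analytic Kroupa normalization can be cross-checked numerically. This sketch assumes only NumPy. It rebuilds the continuity coefficients with the same rule as `_setup`. It then integrates each segment with a plain trapezoid rule on a log-spaced grid, so that the kinks at the edges fall on grid points. The `trapezoid` helper is hypothetical and written out by hand to avoid differences between NumPy versions. The result should come out close to the `s._norm ~ 0.5033` value checked above:

```python
import numpy as np

# Kroupa 2001 edges and dN/dm slopes (same values as Kroupa2001)
edges = np.array([0.01, 0.08, 0.5, 1.0, 150.0])
slopes = np.array([-0.3, -1.3, -2.3, -2.3])

# continuity coefficients: c_k = c_{k-1} * edge_k ** (a_{k-1} - a_k)
coeffs = np.ones(len(slopes))
for k in range(1, len(slopes)):
    coeffs[k] = coeffs[k - 1] * edges[k] ** (slopes[k - 1] - slopes[k])


def trapezoid(y, x):
    """Plain trapezoid rule on an arbitrary (here log-spaced) grid."""
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))


norm = 0.0
for lo, hi, a, c in zip(edges[:-1], edges[1:], slopes, coeffs):
    m = np.geomspace(lo, hi, 20001)  # log-spaced grid within one segment
    norm += trapezoid(c * m ** a, m)

print(f"numerical norm = {norm:.4f}")
```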
x = 10 ** np.linspace(np.log10(0.01), np.log10(150), 100)

for cl in (Salpeter, Kroupa2001, Kennicutt):
    s = cl()
    y = s(x)
    plt.loglog(x, y, '-', label=s.name)
plt.legend()
plt.ylabel(r'dN / dm')
plt.xlabel(r'mass m')
[Figure: log-log plot of dN/dm versus mass m for the Salpeter, Kroupa 2001 and Kennicutt 1983 IMFs]