# A brief introduction to plotting

This notebook will provide useful examples of producing plots with [Matplotlib](https://matplotlib.org/), but do keep in mind that many of the plots below are meant to demonstrate something specific and might therefore be lacking components that should normally be present.

Experimenting with the code is encouraged.


## `matplotlib`

### Simple plots

Most often you will encounter data plotted in two dimensions as individual data points, as a line that goes through those points, or as both. Simple examples of the three cases are presented below.

In [None]:
# These are normally needed
import numpy as np
import matplotlib.pyplot as plt


def simple_function(x):
    return x**2 - x


# Generate random data for plotting
x = np.linspace(0, 5, 100)
rng = np.random.default_rng()
y = x**2 - x + rng.standard_normal(x.size)

plt.scatter(x, y)
plt.show()

plt.plot(x, y)
plt.show()

plt.plot(x, y, marker="o")
plt.show()

The good thing about the plots above is that they were quick and simple to make. The bad thing about them is that they are missing too much information to be useful. But before we move on to improving the plots we should have an overview of basic terminology, so we start by labelling the different components of a plot.

In [None]:
# Much of the code below is not normally required
from matplotlib.ticker import AutoMinorLocator, FuncFormatter

fig, ax = plt.subplots()
ax.scatter(x, y, label="Markers")
ax.plot(x, simple_function(x), label="Line", color="r")
ax.legend(title="Legend")
ax.set_title("Title")
ax.set_xlabel("x-axis label")
ax.set_ylabel("y-axis label")
secax = ax.secondary_yaxis("right")
secax.yaxis.set_minor_locator(AutoMinorLocator(2))
# The formatters ignore the tick value and position and always return a fixed label
secax.yaxis.set_minor_formatter(FuncFormatter(lambda value, pos: "Minor ticks"))
secax.yaxis.set_major_formatter(FuncFormatter(lambda value, pos: "Major ticks"))
secax.tick_params(which="both", width=2)
secax.tick_params(which="major", length=10)
secax.tick_params(which="minor", length=5)
plt.show()


In `matplotlib` minor ticks are hidden by default and major ticks are usually automatically placed at reasonable intervals.
However, the legend labels, the *x*- and *y*-axis labels and the title must be provided by the user.
Specifying the title and axes labels is not complicated.

In [None]:
plt.plot(x, y)
plt.title("Measurement results")
plt.xlabel("Time [s]")
plt.ylabel(r"Speed [m$\,$s$^{-1}$]")
plt.show()

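As you can see, simple $\TeX$ math commands can be used in the labels. Matplotlib renders them with its built-in mathtext engine, which supports a subset of $\TeX$. Raw strings (`r"..."`) keep Python from interpreting the backslashes.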

The labels that will appear in the legend can be specified when the data is plotted.

In [None]:
plt.plot(x, simple_function(x), label="Model")
plt.plot(x, y, label="Data")
plt.legend()
plt.show()

The data can come in a few different formats. One option is to pass the *x* and *y* coordinates as two 1D arrays of equal length, one pair per `plt.plot()` call.

In [None]:
# Generate more data
offsets = np.arange(5)
y_mult = offsets + y[:, np.newaxis]

plt.plot(x, y_mult[:, 0], label=offsets[0])
plt.plot(x, y_mult[:, -1], label=offsets[-1])
plt.legend(title="Offset")
plt.show()

It is also possible to provide *x* as a 1D array and *y* as a 2D array whose first dimension has the same length as *x*. Each column of *y* is then plotted as a separate line. In that case it is more convenient to provide the labels directly in the `plt.legend()` call.

In [None]:
print(f"Shape of x: {x.shape}")
print(f"Shape of y_mult: {y_mult.shape}")
plt.plot(x, y_mult)
plt.legend(title="Offsets", labels=offsets)
plt.show()

If the *x*- and *y*-coordinates are both provided as 2D arrays, they need to have the same shape.

In [None]:
x_mult = offsets + x[:, np.newaxis]

print(f"Shape of x_mult: {x_mult.shape}")
print(f"Shape of y_mult: {y_mult.shape}")
plt.plot(x_mult, y_mult)
plt.legend(title="Offsets", labels=offsets)
plt.show()

`matplotlib` will attempt to place the legend so that it does not cover the data, but it does not always succeed.
Although the legend can be placed at an arbitrary position, simply using the `loc` keyword often provides a quick solution.

In [None]:
for loc in ("upper left", "center right"):
    plt.plot(x, y, label="data")
    plt.legend(loc=loc)
    plt.title(f"Legend in {loc}")
    plt.show()
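The `loc` keyword also accepts a pair of numbers. Matplotlib interprets them as the position of the lower-left corner of the legend in axes coordinates, where (0, 0) is the lower-left and (1, 1) the upper-right corner of the axes. Here is a minimal sketch with made-up data:

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 5, 100)
y = x**2 - x

fig, ax = plt.subplots()
ax.plot(x, y, label="data")
# (0.05, 0.6): lower-left corner of the legend box, in axes coordinates
legend = ax.legend(loc=(0.05, 0.6))
ax.set_title("Legend at an explicit position")
plt.show()
```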

As you have already seen, `matplotlib` automatically cycles through colors so that different data sets are distinguishable.
But it is also possible to specify the line or marker color manually, and likewise many other properties.

In [None]:
for i, color, linestyle in zip(range(4), ("b", "y", "r", "k"), ("-", "-.", ":", "--")):
    plt.plot(x, y + 5 * i, color=color, linestyle=linestyle, label=i)
plt.legend()
plt.show()
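Instead of fixed color names, you can also refer to the colors of the default cycle with the strings `"C0"`, `"C1"`, and so on. This is handy when, for example, a line and the markers belonging to it should share a color. Here is a small sketch with synthetic curves:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex

x = np.linspace(0, 5, 100)
x_points = x[::10]

fig, ax = plt.subplots()
for i in range(3):
    # The line and the markers for the same offset share the i-th cycle color
    ax.plot(x, x**2 - x + 5 * i, color=f"C{i}", label=f"Offset {5 * i}")
    ax.plot(x_points, x_points**2 - x_points + 5 * i, "o", color=f"C{i}")
ax.legend()
plt.show()

# The colors of the current cycle, as hex strings
cycle_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
print(cycle_colors[:3])
```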

It is also possible to use a less verbose, though potentially more obscure, format string specification.

In [None]:
plt.plot(x, y, "*y")
plt.show()
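In the format string `"*y"`, the `*` selects star-shaped markers and `y` selects the color yellow. Because no line style is given, no connecting line is drawn. The sketch below checks that this matches the explicit keyword form:

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 5, 20)
y = x**2 - x

fig, ax = plt.subplots()
(short,) = ax.plot(x, y, "*y")  # format string: marker, then color
(explicit,) = ax.plot(x, y + 2, marker="*", color="y", linestyle="None")
plt.show()

for line in (short, explicit):
    print(line.get_marker(), line.get_color(), line.get_linestyle())
```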

The size of the figure can be changed with the `figsize` argument, given as (width, height) in inches.

In [None]:
for i in (2, 4, 6):
    plt.figure(figsize=(i, i))
    plt.scatter(x, y)
    plt.show()
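Together with the resolution `dpi` (dots per inch), `figsize` determines the size of the rendered image in pixels. A minimal sketch:

```python
import matplotlib.pyplot as plt

# 6 x 3 inches at 100 dots per inch corresponds to a 600 x 300 pixel image
fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
ax.plot([0, 1, 2], [0, 1, 4])

size_inches = fig.get_size_inches()
size_pixels = size_inches * fig.dpi
print(size_inches, size_pixels)
plt.show()
```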

It is also possible to restrict the range of the axes, for example to zoom in on part of the data.

In [None]:
plt.plot(x, y)
plt.xlim((1, 3))
plt.show()
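`plt.xlim()` only changes the visible range. The data outside it is still part of the plot; it is just not shown. The *y* range can be set in the same way. With the object-oriented interface, the equivalent methods are `ax.set_xlim()` and `ax.set_ylim()`. A sketch:

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 5, 100)
y = x**2 - x

fig, ax = plt.subplots()
(line,) = ax.plot(x, y)
ax.set_xlim(1, 3)  # visible x range
ax.set_ylim(0, 7)  # visible y range
plt.show()

print(ax.get_xlim(), ax.get_ylim())
# All 100 points are still stored in the line, even though only some are visible
print(len(line.get_xdata()))
```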

### Errorbars and confidence intervals

Measured data is never known perfectly, and this uncertainty should be reflected in the plots. The uncertainties of discrete points can be represented with error bars.

In [None]:
# Only plot every 10th data point
mask = np.arange(0, x.size, 10, dtype=int)
# We set the errors in x to be constant and in y to scale with x
plt.errorbar(x[mask], y[mask], xerr=0.2, yerr=x[mask], marker="o", linestyle="")
plt.show()
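When the uncertainties are asymmetric, `yerr` (and likewise `xerr`) can be given as an array of shape `(2, N)`. The first row holds the distances below the points and the second row the distances above them. A minimal sketch with made-up values:

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(1, 6)
y = x**2
# Row 0: distance below each point; row 1: distance above each point
yerr = np.array([[0.5, 1.0, 1.5, 2.0, 2.5],
                 [1.0, 2.0, 3.0, 4.0, 5.0]])

fig, ax = plt.subplots()
# capsize adds short horizontal caps at the ends of the error bars
container = ax.errorbar(x, y, yerr=yerr, marker="o", linestyle="", capsize=3)
ax.set_xlabel("x")
ax.set_ylabel("y")
plt.show()
```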

It is quite common to have many estimates of *y* for every value of *x*, where the uncertainty in *x* is much smaller than in *y*. It is then often a good idea to plot the mean of the *y* values at each *x* as a line, and to mark the confidence intervals with shaded regions spanning multiples of the standard deviation of *y* around the line.

In [None]:
# Generate many random samples
n_samples = 80
errors = rng.standard_normal((n_samples, x.size))
errors *= np.linspace(0.2, 2, x.size)
y_mult = simple_function(x) + errors

y_mean = np.mean(y_mult, axis=0)
y_std = np.std(y_mult, axis=0, ddof=1)
for i in range(1, 4):
    plt.fill_between(
        x,
        y1=y_mean + i * y_std,
        y2=y_mean - i * y_std,
        alpha=0.25 - 0.05 * i,
        color="b",
        label=rf"${i}\sigma$",
    )
plt.plot(x, y_mean)
plt.legend(title="Confidence intervals")
plt.show()
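Shading multiples of the standard deviation implicitly assumes roughly Gaussian scatter. If that is doubtful, a common alternative is to shade between percentiles of the samples. For a Gaussian, the 16th and 84th percentiles correspond approximately to ±1σ. A sketch with synthetic data similar to the above:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(42)
x = np.linspace(0, 5, 100)
# 80 noisy samples of x**2 - x with noise that grows with x
samples = x**2 - x + rng.standard_normal((80, x.size)) * np.linspace(0.2, 2, x.size)

# Percentiles are computed separately for every x (axis=0 runs over the samples)
lower, median, upper = np.percentile(samples, [16, 50, 84], axis=0)

fig, ax = plt.subplots()
ax.fill_between(x, lower, upper, alpha=0.3, label="16th to 84th percentile")
ax.plot(x, median, label="Median")
ax.legend()
plt.show()
```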

### Histograms

Creating histograms is as straightforward as you might expect.

In [None]:
plt.hist(y)
plt.title("A histogram")
plt.xlabel("y")
plt.ylabel("Number of occurrences in bins")
plt.show()
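By default, `plt.hist()` uses 10 equally wide bins. The `bins` argument sets a different number of bins (or explicit bin edges), and `density=True` normalizes the histogram so that its total area is 1. The function also returns the counts and the bin edges, which is handy for checking the result. A sketch with synthetic data:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
data = rng.standard_normal(1000)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))
counts, edges, _ = ax1.hist(data, bins=20)
ax1.set_ylabel("Number of occurrences in bins")
density, density_edges, _ = ax2.hist(data, bins=20, density=True)
ax2.set_ylabel("Probability density")
for ax in (ax1, ax2):
    ax.set_xlabel("Value")
fig.tight_layout()
plt.show()

# 20 bins have 21 edges; with density=True the bar areas sum to 1
print(len(edges), counts.sum(), np.sum(density * np.diff(density_edges)))
```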

### Multidimensional data

One option for displaying multidimensional data is to show it on multiple subplots. This is best achieved using `plt.subplots()`, which returns a figure together with a grid of `Axes` objects that can be worked with individually.

In [None]:
# Generate even more random data
z = rng.standard_normal(x.size)

fig, ax = plt.subplots(2, 2)

ax[0, 0].scatter(x, y)
ax[0, 0].set_xlabel("x")
ax[0, 0].set_ylabel("y")

ax[0, 1].scatter(z, y)
ax[0, 1].set_xlabel("z")
ax[0, 1].set_ylabel("y")

ax[1, 1].scatter(z, x)
ax[1, 1].set_xlabel("z")
ax[1, 1].set_ylabel("x")

ax[1, 0].remove()

fig.tight_layout()
plt.show()
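When the panels show the same quantity on one of the axes, `plt.subplots()` can make them share that axis with `sharex` or `sharey`. Setting limits on one panel then applies to all of them. A sketch with synthetic data:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(1)
x = np.linspace(0, 5, 100)
y = x**2 - x + rng.standard_normal(x.size)
z = rng.standard_normal(x.size)

# One row with two panels that share the y-axis
fig, axes = plt.subplots(1, 2, sharey=True, figsize=(8, 3))
axes[0].scatter(x, y)
axes[0].set_xlabel("x")
axes[0].set_ylabel("y")
axes[1].scatter(z, y)
axes[1].set_xlabel("z")
# Setting the limits on one panel also changes the other
axes[0].set_ylim(0, 10)
fig.tight_layout()
plt.show()
```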

But it is also possible to convey multidimensional data in a single scatter plot by making the marker sizes dependent on the third variable.

In [None]:
fig, ax = plt.subplots()
scatter = ax.scatter(x, z, s=10 * (y + 5))
ax.set_xlabel("x")
ax.set_ylabel("z")
plt.legend(*scatter.legend_elements(prop="sizes"), title=r"$10(y+5)$")
plt.show()

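Unfortunately, the legend covers the data and the data points also overlap. The plot below addresses these shortcomings by making the markers semi-transparent with `alpha` and by moving the legend outside the axes with `bbox_to_anchor`.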

In [None]:
fig, ax = plt.subplots()
scatter = ax.scatter(x, z, s=10 * (y + 5), alpha=0.5)
ax.set_xlabel("x")
ax.set_ylabel("z")
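plt.legend(
    *scatter.legend_elements(prop="sizes"),
    # Anchor the upper-left corner of the legend just outside the right edge of the axes
    bbox_to_anchor=(1.02, 1),
    loc="upper left",
    title=r"$10(y+5)$",
)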
plt.show()

Another option is to encode the third variable in the marker color and add a color bar.

In [None]:
plt.scatter(x, z, c=y)
plt.xlabel("x")
plt.ylabel("z")
plt.colorbar(label="y")
plt.show()
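The colormap can be chosen with `cmap`. The `vmin` and `vmax` arguments fix the range of values that the colors span, which keeps the colors comparable between different plots. A sketch with synthetic data:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(2)
x = np.linspace(0, 5, 100)
y = x**2 - x + rng.standard_normal(x.size)
z = rng.standard_normal(x.size)

fig, ax = plt.subplots()
# Values of y below 0 or above 20 get the colors at the ends of the colormap
sc = ax.scatter(x, z, c=y, cmap="plasma", vmin=0, vmax=20)
ax.set_xlabel("x")
ax.set_ylabel("z")
fig.colorbar(sc, ax=ax, label="y")
plt.show()
```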

### Multidimensional histograms

It is also possible to create 2D histograms. We can choose between rectangular and hexagonal bins.

In [None]:
plt.hist2d(x, y)
plt.colorbar(label="Number of occurrences in a bin")
plt.show()

plt.hexbin(x, y, gridsize=10)
plt.colorbar(label="Number of occurrences in a bin")
plt.show()
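As with 1D histograms, the number of bins can be set with `bins`, here given as one number per axis. `plt.hist2d()` also returns the counts, with the first index running over the *x* bins. `plt.hexbin()` accepts `mincnt`, which leaves hexagons with fewer points empty. A sketch with synthetic, correlated data:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(3)
x = rng.standard_normal(5000)
y = x + rng.standard_normal(5000)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 3.5))
# 30 bins along x and 15 along y
counts, xedges, yedges, image = ax1.hist2d(x, y, bins=(30, 15))
fig.colorbar(image, ax=ax1, label="Number of occurrences in a bin")
# Hexagons with no points stay empty instead of getting the lowest color
hb = ax2.hexbin(x, y, gridsize=20, mincnt=1)
fig.colorbar(hb, ax=ax2, label="Number of occurrences in a bin")
fig.tight_layout()
plt.show()
```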

### Grids of data

Often the data consists of values at regularly spaced points on a 2D grid. A simple option for plotting such data is to show the values as colored pixels in an image.

In [None]:
# Pixel width in x
resolution = 1
x = np.arange(-5, 5 + resolution, resolution)
y = 2 * x
# Create 2D arrays from x and y
xx, yy = np.meshgrid(x, y)
z = np.sqrt((xx - 2) ** 2 + (yy - 4) ** 2)

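# interpolation="none" (the string) turns off smoothing between pixels;
# interpolation=None would instead use the default from rcParams.
# The extent is widened by half a pixel so that pixel centers sit on the grid points.
dx = x[1] - x[0]
dy = y[1] - y[0]
plt.imshow(
    z,
    interpolation="none",
    extent=(x[0] - dx / 2, x[-1] + dx / 2, y[0] - dy / 2, y[-1] + dy / 2),
    origin="lower",
)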
plt.colorbar()
plt.show()
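An alternative to `plt.imshow()` is `plt.pcolormesh()`. It takes the coordinates explicitly and therefore also works for grids whose spacing is not uniform. With `shading="nearest"`, each cell is centered on the corresponding grid point. A sketch on a grid like the one above:

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(-5, 6, 1)
y = 2 * x
xx, yy = np.meshgrid(x, y)
z = np.sqrt((xx - 2) ** 2 + (yy - 4) ** 2)

fig, ax = plt.subplots()
# z has one row per y value and one column per x value
mesh = ax.pcolormesh(x, y, z, shading="nearest")
fig.colorbar(mesh, ax=ax)
ax.set_aspect("equal")
plt.show()
```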

But sometimes it might be preferable to draw contour plots. The *x* and *y* coordinates can be given as 1D or 2D arrays.

In [None]:
# Contour lines
plt.contour(xx, yy, z)
plt.colorbar()
plt.gca().set_aspect("equal")
plt.show()

# Filled contours
plt.contourf(x, y, z)
plt.colorbar()
plt.gca().set_aspect("equal")
plt.show()

# A mixture of both
plt.contour(x, y, z, colors="k")
plt.contourf(xx, yy, z)
plt.colorbar()
plt.gca().set_aspect("equal")
plt.show()
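By default, Matplotlib chooses the contour levels itself. The `levels` argument sets them explicitly, and `clabel()` writes the level values onto the lines. A sketch:

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-5, 5, 101)
y = np.linspace(-10, 10, 201)
xx, yy = np.meshgrid(x, y)
z = np.sqrt((xx - 2) ** 2 + (yy - 4) ** 2)

fig, ax = plt.subplots()
cs = ax.contour(xx, yy, z, levels=[2, 4, 6, 8], colors="k")
ax.clabel(cs, fmt="%.0f")  # label each line with its level
ax.set_aspect("equal")
plt.show()
```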

### Map projections

So far we have made plots using a simple rectangular projection. If we are plotting objects on the sky, then we might want to use a different projection. Matplotlib offers several map projections (`"aitoff"`, `"hammer"`, `"mollweide"` and `"lambert"`), but somewhat inconveniently they are all geographic projections with north pointing up and east pointing right. Astronomical projections instead have north pointing up and east pointing left.

The plots below show what the Galactic midplane looks like in the different available projections, with the ordinary rectilinear projection included for comparison. Note that the map projections expect longitude and latitude in radians, with longitude in the range $[-\pi, \pi]$.

In [None]:
from astropy import units as u
from astropy.coordinates import SkyCoord, Galactic, ICRS

# Generate points along the Galactic midplane and convert their coordinates to ICRS
midplane = SkyCoord(frame=Galactic, l=np.arange(-180, 181) * u.deg, b=0 * u.deg)
midplane = midplane.transform_to(ICRS)

for projection in ("rectilinear", "aitoff", "hammer", "mollweide", "lambert"):
    plt.axes(projection=projection)
    # Sadly we have to handle units explicitly here
    plt.scatter((midplane.ra - 180 * u.deg).rad, midplane.dec.rad)
    plt.title(projection.title())
    plt.grid(True)
    plt.show()
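One common workaround for the east-right orientation is to negate the longitude so that RA increases to the left, and then relabel the longitude ticks with the corresponding RA values. The sketch below assumes the coordinates are plain NumPy arrays of RA and Dec in degrees; the values are made up for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical example coordinates in degrees
ra_deg = np.array([0.0, 45.0, 90.0, 150.0, 270.0, 315.0])
dec_deg = np.array([0.0, 30.0, -20.0, 60.0, -45.0, 10.0])

# Wrap RA into (-180, 180] degrees and negate it so that RA increases to the left
ra_wrapped = np.where(ra_deg > 180, ra_deg - 360, ra_deg)
lon = np.radians(-ra_wrapped)
lat = np.radians(dec_deg)

fig = plt.figure()
ax = fig.add_subplot(projection="mollweide")
ax.scatter(lon, lat)
# Longitude grid lines every 30 degrees: at -150, -120, ..., 150 degrees
ax.set_longitude_grid(30)
tick_degrees = np.arange(-150, 151, 30)
# Label each tick with the RA that its (negated) position corresponds to
tick_labels = [f"{(-t) % 360}°" for t in tick_degrees]
ax.set_xticklabels(tick_labels)
ax.grid(True)
plt.show()
```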

## `astropy` `Quantity`

The following example illustrates how to use the `astropy` `Quantity` class together with `matplotlib` to ensure that the values on a plot are always in the units that the labels claim them to be in.

In [None]:
from astropy.table import QTable

# Create a QTable
labels = ["Earth", "Jupiter", "Sun"]
m = [1 * u.M_earth, 1 * u.M_jupiter, 1 * u.M_sun]
r = [1 * u.R_earth, 1 * u.R_jupiter, 1 * u.R_sun]
astrodata = QTable((labels, m, r), names=["name", "mass", "radius"])
astrodata["density"] = (
    astrodata["mass"] / (4 * np.pi / 3 * astrodata["radius"] ** 3)
).to(u.g / u.cm**3)

# Plot the radii and masses with two different mass units
for munit in (u.kg, u.M_earth):
    astrodata["mass"] = astrodata["mass"].to(munit)
    for elem in astrodata:
        plt.loglog(elem["mass"], elem["radius"], "o", label=elem["name"])
    # Pay attention to the use of single and double quotation marks
    plt.xlabel(f'Mass [{astrodata["mass"].unit.to_string("latex")}]')
    plt.ylabel(f'Radius [{astrodata["radius"].unit.to_string("latex")}]')
    plt.legend()
    plt.show()

for str_format in ("latex", "latex_inline"):
    for elem in astrodata:
        plt.semilogx(elem["mass"], elem["density"], "o", label=elem["name"])
    # Pay attention to the use of single and double quotation marks
    plt.xlabel(f'Mass [{astrodata["mass"].unit.to_string(str_format)}]')
    plt.ylabel(f'Density [{astrodata["density"].unit.to_string(str_format)}]')
    plt.legend()
    plt.show()

## seaborn

Although `matplotlib` allows us to have a lot of control over what the plots will look like, it might not be the best tool for quickly producing plots of multidimensional data.
In that case it might be preferable to use [seaborn](https://seaborn.pydata.org/).
Because `seaborn` is built on top of `matplotlib` it is possible to tweak `seaborn` plots with `matplotlib` commands.

Although `seaborn` is designed to work with [pandas](https://pandas.pydata.org/) `DataFrame` objects instead of `astropy` `Table` or `QTable`, the `to_pandas()` method of `astropy` tables allows for quick conversion.

The benefits of using `seaborn` are not necessarily apparent from the small example below, but it can be very useful if the data set has many elements and dimensions.

In [None]:
import pandas as pd
import seaborn as sns

# Convert the QTable to a DataFrame
df = astrodata.to_pandas()

sns.relplot(
    x="mass",
    y="radius",
    hue="density",
    style="name",
    data=df,
    s=200,
    palette="dark",
)
# Use Matplotlib calls to convert the axes to logarithmic scales
plt.xscale("log")
plt.yscale("log")
plt.show()

## Customizing Matplotlib

There are many ways of changing the style of a figure in Matplotlib.
Here we will go through a few.

#### 1) Function arguments
As has been demonstrated above, the properties of an individual element can be changed by passing the setting directly to the function that creates it. Examples include
```python
plt.plot(x, y, color="red", linestyle="-.")
```
and
```python
plt.xlabel(r"Temperature [$\degree$C]", fontsize=16)
```

#### 2) Context managers
If you want to change a setting for a block of code but not the whole file, a context manager can be used.
For example, if you want to make this figure

In [None]:
plt.plot(x, y, label="My data")
plt.xlim((1, 3))
plt.legend()
plt.show()

but you want a much bigger font for all the text, you can use

In [None]:
with plt.style.context({"font.size": 20}):
    plt.plot(x, y, label="My data")
    plt.xlim((1, 3))
    plt.legend()
    plt.show()

The changed setting is applied to all code inside the `with` block but then automatically reset outside it.
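For changing a few individual settings, `matplotlib.rc_context()` does the same job as a dictionary passed to `plt.style.context()`. The sketch below checks that the setting really is reset after the `with` block:

```python
import matplotlib as mpl
import matplotlib.pyplot as plt

before = mpl.rcParams["font.size"]
with mpl.rc_context({"font.size": 20}):
    inside = mpl.rcParams["font.size"]
    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, 1], label="My data")
    ax.legend()
    plt.show()
after = mpl.rcParams["font.size"]
print(before, inside, after)
```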

#### 3) rcParams
It is possible to set the runtime configuration (rc) settings for Matplotlib in your script.
This is done by writing
```python
import matplotlib as mpl
```
and setting an entry in `mpl.rcParams` to the desired value, for example
```python
mpl.rcParams['lines.linewidth'] = 2
```
These changes then remain in effect until they are changed or the program is exited.
The Matplotlib documentation contains [a full list of parameters that can be changed this way](https://matplotlib.org/stable/api/matplotlib_configuration_api.html#matplotlib.rcParams).
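Because such changes persist, it can be useful to undo them. `mpl.rcdefaults()` restores Matplotlib's built-in defaults, which also discards settings from style sheets or a `matplotlibrc` file. `mpl.rc_file_defaults()` instead restores the settings from the rc file that Matplotlib originally loaded. A sketch:

```python
import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.rcParams["lines.linewidth"] = 2
fig, ax = plt.subplots()
(thick,) = ax.plot([0, 1], [0, 1])
plt.show()

# Restore Matplotlib's built-in defaults
mpl.rcdefaults()
fig, ax = plt.subplots()
(default,) = ax.plot([0, 1], [0, 1])
plt.show()
print(thick.get_linewidth(), default.get_linewidth())
```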

#### 4) Style sheets
In the presentation, we used a style sheet
```python
plt.style.use(
    "https://lund-observatory-teaching.github.io/lundpython/3-plotting/presentation.mplstyle"
)
```
A style sheet is a file containing Matplotlib settings. After a style sheet has been loaded with the command shown above, its settings are applied to all subsequent plots.
Matplotlib also provides [a list of ready-made style sheets](https://matplotlib.org/stable/gallery/style_sheets/style_sheets_reference.html) for anyone to use.
It is also possible to [create your own style sheet](https://matplotlib.org/stable/users/explain/customizing.html#customizing-with-style-sheets).
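A style sheet is a plain text file with one `key: value` pair per line, using the same keys as `mpl.rcParams`. The sketch below writes a small hypothetical style sheet to a temporary file. It then applies the style only within a `with` block, since `plt.style.context()` accepts a file path as well as a style name.

```python
import os
import tempfile

import matplotlib.pyplot as plt

# Hypothetical style settings, one rcParams key per line
style_text = """
lines.linewidth: 3
axes.grid: True
font.size: 14
"""

with tempfile.TemporaryDirectory() as tmpdir:
    style_path = os.path.join(tmpdir, "my_style.mplstyle")
    with open(style_path, "w") as f:
        f.write(style_text)

    # The style is only active inside this block
    with plt.style.context(style_path):
        linewidth_inside = plt.rcParams["lines.linewidth"]
        grid_inside = plt.rcParams["axes.grid"]
        fig, ax = plt.subplots()
        ax.plot([0, 1, 2], [0, 1, 4])
        plt.show()
```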

Additionally, feel free to consult the [official documentation](https://matplotlib.org/stable/users/explain/customizing.html) at any time for even more ways of changing styles.