# Matplotlib exercises

### 1) Some simple plots

#### Related manual sections: Simple plots, Histograms

Run the following code cell to generate some simple datasets. Try to understand what each line generates.

In [None]:
import numpy as np
from numpy.random import default_rng
import matplotlib.pyplot as plt
from scipy.stats import gumbel_r, gumbel_l, norm

rng = default_rng()

t = np.arange(0.0, 10, 0.1)  # Time (seconds)
sine = 2 * np.sin(np.pi * t)  # Voltage (mV)

x = rng.gumbel(loc=37, scale=0.5, size=365)  # Body temperature measurements (Celsius)

hour = np.linspace(0, 24, 500)  # Hour of the day
td = 100 * np.dot(
    np.array([2, 0.3, 0.7]).T,
    [
        gumbel_r.pdf(hour, loc=7, scale=2),
        norm.pdf(hour, loc=15, scale=0.7),
        gumbel_l.pdf(hour, loc=19, scale=0.8),
    ],
)
td += (np.max(td) / 10) * np.abs(
    rng.normal(scale=2, size=td.size),
)  # Traffic density (N_cars / km)
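
The `td` line in the cell above may be the hardest one to read. As a minimal sketch of what it does, the example below uses three made-up stand-in curves (`curves`, `grid`, and `mixture` are hypothetical names, not part of the exercise) to show that `np.dot` of a length-3 weight vector with a `(3, N)` array is a weighted sum of the three rows. Transposing a 1D array with `.T` leaves it unchanged, so the `.T` in the cell above has no effect.

```python
import numpy as np

# Hypothetical stand-in curves, not the distributions used in the exercise
grid = np.linspace(0, 1, 5)
curves = np.array(
    [
        np.ones_like(grid),  # constant curve
        grid,  # rising line
        grid**2,  # parabola
    ]
)  # shape (3, 5): one row per curve
weights = np.array([2, 0.3, 0.7])  # one weight per curve, as in the exercise

mixture = np.dot(weights, curves)  # weighted sum over the rows, shape (5,)
manual = 2 * curves[0] + 0.3 * curves[1] + 0.7 * curves[2]

print(curves.shape, mixture.shape)  # (3, 5) (5,)
print(np.allclose(mixture, manual))  # True
```

In the exercise, the three rows are probability density curves evaluated at every value of `hour`, so `td` is a sum of three weighted peaks with non-negative noise added on top.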

Using these datasets, create the following plots:

1. A sinusoid line plot showing mV oscillation over time in seconds.
2. A histogram that shows the different body temperature measurements.
3. A scatter plot with traffic density as a function of the time of day.

Your output should look like this:  

![](https://lund-observatory-teaching.github.io/lundpython/imgs/figures.jpeg)

Write your solution:

<details>
  <summary><b>Click to reveal solution</b></summary>

```python
plt.figure()
plt.plot(t, sine)
plt.show()

plt.figure()
plt.hist(x)
plt.show()

plt.figure()
plt.scatter(hour, td)
plt.show()
```
  
</details>

<hr style="border:1.5px solid gray"></hr>

### 2) Improving plots

#### Related manual sections: Simple plots, Histograms

We made some improvements to the plots:

![](https://lund-observatory-teaching.github.io/lundpython/imgs/figures_improved.jpeg)

Every plot should include the following changes:
- Add a title.
- Add labels with good font sizes.
- Add units to labels.

Make these required changes first. If time permits, you can also try to reproduce the additional changes we made:

- Change the sine color to red.
- Add a grid to the sine plot.
- Change the number of histogram bins.
- Change the scatter plot marker type, size, and color.
- Change the figure sizes.
- Change the tick labels and their font sizes.

<details>
  <summary>Click here for a hint!</summary>

Look at the documentation for each of the plotting functions to see what options they have. 
    
You can also take a look at [Matplotlib's sample plots](https://matplotlib.org/stable/gallery/index.html) to see how they do it.
    
</details>
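
As a rough illustration of the kind of options the documentation describes, here is a minimal sketch on made-up data that is unrelated to the exercise datasets. The variable names (`t_fall`, `distance`, `fontsize`) and the styling values are assumptions chosen for the example. It uses the `Axes` methods (`set_xlabel`, `set_title`, ...), which correspond to the `plt.xlabel`, `plt.title`, ... functions.

```python
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical example data: distance fallen under constant gravity
t_fall = np.linspace(0, 3, 50)  # Time (s)
distance = 0.5 * 9.81 * t_fall**2  # Distance (m)

fontsize = 14  # one variable keeps all font sizes consistent
fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(t_fall, distance, color="tab:blue", linestyle="--")
ax.set_xlabel("Time [s]", fontsize=fontsize)  # label with unit
ax.set_ylabel("Distance [m]", fontsize=fontsize)
ax.set_title("Free fall without air resistance", fontsize=fontsize)
ax.tick_params(labelsize=fontsize - 2)  # slightly smaller tick labels
ax.grid(True)
fig.tight_layout()
```

In a notebook the figure is displayed at the end of the cell. When running this as a script, add `plt.show()` at the end.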

Write your solution:

<details>
  <summary><b>Click to reveal solution</b></summary>

```python
fs = 16
plt.figure(figsize=(6, 3))
plt.plot(t, sine, "r")
plt.xlabel("Time [s]", fontsize=fs)
plt.xticks(fontsize=fs - 2)
plt.xlim([0, 10])
plt.ylabel("Voltage [mV]", fontsize=fs)
plt.yticks(fontsize=fs - 2)
plt.ylim([-2.2, 2.2])
plt.title("AC Transistor", fontsize=fs)
plt.grid()
plt.savefig("sine_good.png", bbox_inches="tight")

plt.figure(figsize=(5, 5))
plt.hist(x, bins=20, histtype="stepfilled")
# Raw strings (r"...") keep the backslashes in the mathtext commands intact
plt.xlabel(r"Temperature [$\degree$C]", fontsize=fs)
plt.xticks(np.arange(36, 41, 1), fontsize=fs - 2)
plt.ylabel(r"$N_\mathrm{Days}$", fontsize=fs)
plt.yticks(fontsize=fs - 2)
plt.title("My body temperature over a year", fontsize=fs)
plt.savefig("hist_good.png")

plt.figure(figsize=(8, 4))
plt.scatter(hour, td, s=2, c="k")
plt.xlabel("Time [Hours]", fontsize=fs)
plt.xticks(np.arange(0, 25, 3), fontsize=fs - 2)
plt.xlim([0, 24])
plt.ylabel(r"Traffic density [N$_\mathrm{cars}$ / km]", fontsize=fs)
plt.yticks(fontsize=fs - 2)
plt.ylim([0, 50])
plt.title("How does traffic vary over a day?", fontsize=fs)
plt.savefig("scatter_good.png")
```
  
</details>

<hr style="border:1.5px solid gray"></hr>

### 3) Subplots

#### Related manual sections: Multidimensional data, Multidimensional histograms

Run the following code cell, which generates two arrays `x` and `y`, each of shape (5, 100000). Note that this overwrites the `x` array from exercise 1.

In [None]:
x = np.zeros((5, 100000))
y = np.zeros((5, 100000))
for i in range(5):
    gen_mean = rng.uniform(low=-3, high=3, size=(3, 2))
    gen_cov = rng.uniform(low=0, high=1, size=(3, 2))
    x[i], y[i] = np.vstack(
        (
            rng.multivariate_normal(
                mean=gen_mean[0],
                cov=[[gen_cov[0, 0], 0], [0, gen_cov[0, 1]]],
                size=40000,
            ),
            rng.multivariate_normal(
                mean=gen_mean[1],
                cov=[[gen_cov[1, 0], 0], [0, gen_cov[1, 1]]],
                size=20000,
            ),
            rng.multivariate_normal(
                mean=gen_mean[2],
                cov=[[gen_cov[2, 0], 0], [0, gen_cov[2, 1]]],
                size=40000,
            ),
        ),
    ).T
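
To see how the last step of the loop fills one row of `x` and one row of `y`, here is a minimal sketch with much smaller clusters. The names `rng_demo`, `a`, `b`, `stacked`, `xs`, and `ys` are hypothetical and chosen so they do not clash with the exercise variables.

```python
import numpy as np

rng_demo = np.random.default_rng(0)  # separate, seeded generator for the demo

# Two small 2D Gaussian clusters, each of shape (n_points, 2)
a = rng_demo.multivariate_normal(mean=[0, 0], cov=[[1, 0], [0, 1]], size=4)
b = rng_demo.multivariate_normal(mean=[3, 3], cov=[[0.5, 0], [0, 0.5]], size=2)

stacked = np.vstack((a, b))  # shape (6, 2): one row per point
xs, ys = stacked.T  # transpose to (2, 6), then unpack the two rows

print(stacked.shape, xs.shape, ys.shape)  # (6, 2) (6,) (6,)
```

In the cell above, the same pattern stacks three clusters of 40000, 20000, and 40000 points, so each `x[i]` and `y[i]` holds 100000 values.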

Use the data to create a 2x3 grid of subplots. In each subplot, show one of the 5 rows of 100 000 randomly generated `x, y` pairs. With 100 000 points per subplot, a 2D histogram such as `hexbin` (see the hints) shows the density better than a plain scatter plot. Since the data is randomly generated, your plots will look slightly different.

Besides the different data, your figure may also differ from the example below in its styling. If time permits, you can make your figure more similar by:

- Setting the x and y limits of your plots.
- Setting the tick label font sizes.
- Removing the empty subplot in the 2nd row and 3rd column.

*Note:* It is also possible to loop over the indices for rows and columns when doing subplots.  
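
The note above could be sketched as follows, on a smaller 2x2 grid with simple line plots rather than the exercise data. The names `angle`, `axes`, and `freq` are assumptions made for this example.

```python
import numpy as np
import matplotlib.pyplot as plt

angle = np.linspace(0, 2 * np.pi, 200)

# With nrows > 1 and ncols > 1, plt.subplots returns a 2D array of Axes
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(6, 5))
for i in range(2):  # row index
    for j in range(2):  # column index
        freq = 2 * i + j + 1  # 1, 2, 3, 4 from left to right, top to bottom
        axes[i, j].plot(angle, np.sin(freq * angle))
        axes[i, j].set_title(f"Frequency {freq}")
fig.tight_layout()
```

An alternative is to iterate over `axes.flat`, which visits the same subplots row by row without explicit indices.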

<br>
  <details>
    <summary>Click here for hints!</summary>
  
      
  The following documentation pages will help you:
      
  [`matplotlib.axes.Axes.hexbin()`](https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.hexbin.html)  
      
  [`plt.subplots()`](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.subplots.html)

  [`matplotlib.axes.Axes.axis`](https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.axis.html) (How to remove a subplot)
    
  [`plt.tight_layout()`](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.tight_layout.html) (Can remove excess whitespace)  
    
  [`Multiple subplots`](https://matplotlib.org/stable/gallery/subplots_axes_and_figures/subplot.html)  
      

  
      
  </details>
  
  ![](https://lund-observatory-teaching.github.io/lundpython/imgs/subplots.png)

Write your solution:

<details>
  <summary><b>Click to reveal solution</b></summary>

```python
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(12, 8))

k = 0
for i in range(2):
    for j in range(3):        
        if (i, j) != (1, 2): 
            x_ax, y_ax = x[k], y[k]
            ax[i, j].hexbin(x_ax, y_ax, gridsize=20)
            ax[i, j].set(xlim=(x_ax.min(), x_ax.max()), ylim=(y_ax.min(), y_ax.max()))
            ax[i, j].set_ylabel("$y$", fontsize=14)
            ax[i, j].set_xlabel("$x$", fontsize=14)
            ax[i, j].tick_params(labelsize=14)
            k += 1
        else:
            ax[i, j].axis("off")
            
plt.tight_layout()
plt.show()
```
  
</details>