122 lines
5.1 KiB
Python
122 lines
5.1 KiB
Python
import itertools
|
|
import logging
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import seaborn as sns
|
|
from celluloid import Camera
|
|
from matplotlib import pyplot as plt
|
|
|
|
from calculations import DEFAULT_MORTALITY_DISTRIBUTION
|
|
|
|
|
|
def plot_mortality(mortality_distribution=DEFAULT_MORTALITY_DISTRIBUTION, subsidy_distribution=None,
                   covid_distribution=None):
    """Plot the per-age deregistration (recycling) probability and save it.

    With only ``mortality_distribution`` a bar chart is drawn. When
    ``subsidy_distribution`` is given, every scenario ("Normal",
    "With Subsidy" and, if ``covid_distribution`` is also given,
    "With Covid") is drawn as a line coloured by scenario.

    The figure is written to ``output/mortality.png`` and
    ``output/mortality.pgf`` (the ``output`` directory must already exist),
    then shown and closed.

    Args:
        mortality_distribution: Baseline recycling probability per age index.
        subsidy_distribution: Optional probabilities for the subsidy scenario;
            must have the same length as ``mortality_distribution`` because
            all scenarios share one age axis.
        covid_distribution: Optional probabilities for the COVID scenario,
            same length as ``mortality_distribution``. Only used when
            ``subsidy_distribution`` is also given.
    """
    # NOTE(review): this reconfigures root logging on every call; it likely
    # belongs at the application entry point. Kept for compatibility.
    logging.basicConfig(level=logging.DEBUG)

    plt.figure()
    ages = np.arange(len(mortality_distribution))
    if subsidy_distribution is None:
        # seaborn >= 0.12 accepts only ``data`` positionally, so x/y go by keyword.
        sns.barplot(x=ages, y=mortality_distribution, color='cornflowerblue')
    else:
        scenarios = [("Normal", mortality_distribution),
                     ("With Subsidy", subsidy_distribution)]
        if covid_distribution is not None:
            scenarios.append(("With Covid", covid_distribution))
        # Long-form frame (one row per scenario and age) so seaborn can
        # colour the lines by scenario.
        df = pd.DataFrame({"Ages": np.tile(ages, len(scenarios)),
                           "Case": [name for name, _ in scenarios for _ in range(len(ages))],
                           "Scenario": np.hstack([values for _, values in scenarios])})
        sns.lineplot(data=df, x="Ages", y="Scenario", hue="Case")

    plt.title("Deregistration Distribution")
    plt.ylabel("Recycling probability")
    plt.xlabel("Age (years)")
    plt.savefig("output/mortality.png")
    plt.savefig("output/mortality.pgf")
    plt.show()
    plt.close()
|
|
|
|
|
|
def plot_population(population_distribution: np.ndarray):
    """Show a bar chart of the vehicle count per age, in millions.

    Only every 10th x tick label is left visible so that a long age axis
    stays readable. The figure is shown and then closed.

    Args:
        population_distribution: Vehicle count per age index.
    """
    plt.figure()
    ages = np.arange(len(population_distribution))
    # seaborn >= 0.12 accepts only ``data`` positionally, so x/y go by keyword.
    plot_ = sns.barplot(x=ages, y=np.asarray(population_distribution) / 1e6, color='cornflowerblue')
    plt.title("Population Distribution")
    plt.ylabel("Vehicles (Millions)")
    plt.xlabel("Age (years)")
    for ind, label in enumerate(plot_.get_xticklabels()):
        # Keep every 10th label; hide the rest to avoid overlap.
        label.set_visible(ind % 10 == 0)
    plt.show()
    plt.close()
|
|
|
|
|
|
def plot_population_development(years, total_pops: np.ndarray):
    """Show the total vehicle population per simulation year (millions).

    The y axis is zoomed onto the data range so that year-to-year changes
    are visible. Only every 10th x tick label is left visible. The figure is
    shown and then closed.

    Args:
        years: x-axis values, one per bar.
        total_pops: Total vehicle count per year, same length as ``years``.
    """
    pops_millions = np.asarray(total_pops) / 1e6

    plt.figure()
    # seaborn >= 0.12 accepts only ``data`` positionally, so x/y go by keyword.
    plot_ = sns.barplot(x=years, y=pops_millions, color='cornflowerblue')
    plt.title("Total Vehicle Population")
    plt.ylabel("Vehicles (Millions)")
    plt.xlabel("Simulation year")

    # Zoom onto [min, max] with a 5% pad. Without the pad the smallest bar
    # would have zero visible height. A flat series (min == max) falls back
    # to a pad relative to the value (or 1.0 for all zeros) so the limits
    # are never identical.
    lo, hi = pops_millions.min(), pops_millions.max()
    pad = 0.05 * (hi - lo) or 0.05 * abs(hi) or 1.0
    plt.ylim(lo - pad, hi + pad)

    for ind, label in enumerate(plot_.get_xticklabels()):
        # Keep every 10th label; hide the rest to avoid overlap.
        label.set_visible(ind % 10 == 0)
    plt.show()
    plt.close()
|
|
|
|
|
|
def plot_production(years, production: np.ndarray):
    """Compare yearly vehicle production: subsidy scenario vs. business as usual.

    The business-as-usual scenario holds production constant at the first
    year's value (``production[0]``) for every year. Each legend label
    carries the scenario's cumulative production in millions. Bars are drawn
    side by side per year, and only every 10th x tick label is left visible.
    The figure is shown and then closed.

    Args:
        years: Simulation years, one per entry of ``production``.
        production: Vehicles produced per year under the subsidy scenario.
    """
    production = np.asarray(production, dtype=float)
    n_years = len(years)
    baseline = np.full(n_years, production[0])

    # Format the totals so the legend does not show float noise such as
    # "36.00000000000001".
    subsidy_label = f"Subsidy ({production.sum() / 1e6:.2f})"
    bau_label = f"Business as Usual ({baseline.sum() / 1e6:.2f})"

    # Long-form frame: one row per (scenario, year), with values already in millions.
    df = pd.DataFrame({"Year": np.hstack((years, years)),
                       "Production": np.hstack((production, baseline)) / 1e6,
                       "Scenario": [subsidy_label] * n_years + [bau_label] * n_years})

    plt.figure()
    plot_ = sns.barplot(data=df, x="Year", y="Production", hue="Scenario")
    plt.title("Total Vehicle Production")
    plt.ylabel("Vehicles (Millions)")
    plt.xlabel("Simulation year")
    # The baseline equals production[0], so the subsidy series alone bounds the data.
    plt.ylim(0.8 * production.min() / 1e6, 1.2 * production.max() / 1e6)
    for ind, label in enumerate(plot_.get_xticklabels()):
        # Keep every 10th label; hide the rest to avoid overlap.
        label.set_visible(ind % 10 == 0)
    plt.show()
    plt.close()
|
|
|
|
|
|
def animate(years, population_distributions, filename='animation.mp4'):
    """Render one bar-chart frame per population distribution and save a video.

    Each frame shows one entry of ``population_distributions`` (in millions)
    as bars, with the first entry overlaid as a blue reference line so the
    drift from the initial state is visible. Frames are captured with
    celluloid's ``Camera``, and the figure is closed after saving.

    Args:
        years: Currently unused; kept for interface compatibility.
            NOTE(review): presumably meant for a per-frame year label — confirm.
        population_distributions: Sequence of per-age vehicle counts, one per
            frame; all entries are assumed to share the first entry's length.
        filename: Output path handed to ``Animation.save``
            (default ``'animation.mp4'``).
    """
    ages = np.arange(len(population_distributions[0]))
    reference = np.asarray(population_distributions[0]) / 1e6

    fig = plt.figure()
    camera = Camera(fig)
    # One frame per distribution: draw it, then snapshot the axes' artists.
    for distribution in population_distributions:
        # seaborn >= 0.12 accepts only ``data`` positionally, so x/y go by keyword.
        plot_ = sns.barplot(x=ages, y=np.asarray(distribution) / 1e6, color='cornflowerblue')
        plt.plot(ages, reference, color='blue')
        for ind, label in enumerate(plot_.get_xticklabels()):
            # Keep every 10th label; hide the rest to avoid overlap.
            label.set_visible(ind % 10 == 0)
        plt.title("Total Vehicle Population")
        plt.ylabel("Vehicles (Millions)")
        plt.xlabel("Vehicle Age")
        camera.snap()

    anim = camera.animate(blit=False)
    anim.save(filename)
    # Release the figure; otherwise it stays registered with pyplot.
    plt.close(fig)