This repository has been archived on 2025-02-17. You can view files and clone it, but cannot push or open issues or pull requests.
BreCal/src/lib_brecal_utils/brecal_utils/database/sql_handler.py

201 lines
8.6 KiB
Python

import numpy as np
import pandas as pd
import datetime
from BreCal.schemas.model import Shipcall, Ship, Participant, Berth, User, Times
from brecal_utils.database.enums import ParticipantType
class SQLHandler():
    """
    An object that reads SQL queries from the sql_connection and stores them in pandas DataFrames. The object can read all available tables
    at once into memory, when providing 'read_all=True'.
    # #TODO_initialization: shipcall_tug_map, user_role_map & role_securable_map might be mapped to the respective dataframes
    """
    def __init__(self, sql_connection, read_all=False):
        """
        sql_connection: open database connection. It is used with pd.read_sql and cursor(buffered=True),
            which suggests mysql.connector — TODO confirm.
        read_all: when True, every table is loaded into self.df_dict immediately.
        """
        self.sql_connection = sql_connection
        # always define df_dict, so that lookups on a handler created with read_all=False fail with the
        # 'cannot find the requested model_str' assertion instead of an AttributeError
        self.df_dict = {}
        self.all_schemas = self.get_all_schemas_from_mysql()
        self.build_str_to_model_dict()
        if read_all:
            self.read_all(self.all_schemas)

    def get_all_schemas_from_mysql(self):
        """runs 'SHOW TABLES' and returns the first column of every result row (the table names) as a list"""
        with self.sql_connection.cursor(buffered=True) as cursor:
            cursor.execute("SHOW TABLES")
            schema = cursor.fetchall()
        all_schemas = [schem[0] for schem in schema]
        return all_schemas

    def build_str_to_model_dict(self):
        """
        creates a simple dictionary, which maps a string to a data object
        e.g.,
        'ship'->BreCal.schemas.model.Ship object
        """
        self.str_to_model_dict = {
            "shipcall":Shipcall, "ship":Ship, "participant":Participant, "berth":Berth, "user":User, "times":Times
        }
        return

    def read_mysql_table_to_df(self, table_name:str):
        """
        determine a {table_name}, which will be read from a mysql server. returns a pandas DataFrame with the respective data.
        Unlike mysql_to_df, the result is neither dtype-converted nor indexed by 'id'.
        """
        # table_name is interpolated verbatim (identifiers cannot be bound as query parameters),
        # so it must come from a trusted source such as self.all_schemas
        df = pd.read_sql(sql=f"SELECT * FROM {table_name}", con=self.sql_connection)
        return df

    def mysql_to_df(self, query):
        """
        provide an arbitrary sql query that should be read from a mysql server {sql_connection}.
        returns a pandas DataFrame with the obtained data (dtypes converted via .convert_dtypes()).
        When the result has an 'id' column, it becomes the index, so rows can be addressed with .loc[id].
        """
        df = pd.read_sql(query, self.sql_connection).convert_dtypes()
        # tables without an 'id' column keep their default RangeIndex instead of raising a KeyError,
        # which would otherwise abort read_all() for the whole database
        if 'id' in df.columns:
            df = df.set_index('id', inplace=False)  # avoid inplace updates, so the raw sql remains unchanged
        return df

    def read_all(self, all_schemas):
        """loads every table in {all_schemas} into self.df_dict and fills the 'participants' column of 'shipcall'"""
        # create a dictionary, which maps every mysql schema to pandas DataFrames
        self.df_dict = self.build_full_mysql_df_dict(all_schemas)
        # update the 'participants' column in 'shipcall'
        self.initialize_shipcall_participant_list()
        return

    def build_full_mysql_df_dict(self, all_schemas):
        """given a list of strings {all_schemas}, every schema will be read as individual pandas DataFrames to a dictionary with the respective keys. returns: dictionary {schema_name:pd.DataFrame}"""
        return {
            schem: self.mysql_to_df(f"SELECT * FROM {schem}")
            for schem in all_schemas
        }

    def initialize_shipcall_participant_list(self):
        """
        applies the .get_participants method to each shipcall and stores the resulting lists
        in the 'participants' column of the 'shipcall' DataFrame (modified in place).
        """
        # 1.) get all shipcalls
        df = self.df_dict.get('shipcall')
        # 2.) the index of 'shipcall' holds the shipcall ids (set in mysql_to_df). Every id is mapped to
        #     a (possibly empty) list of participant ids. A comprehension avoids materialising every
        #     row as a Series, which is what df.apply(..., axis=1) does.
        df['participants'] = [self.get_participants(shipcall_id) for shipcall_id in df.index]
        return

    def standardize_model_str(self, model_str:str)->str:
        """check if the 'model_str' is valid (a key of self.df_dict) and apply lowercasing to the string"""
        model_str = model_str.lower()
        assert model_str in self.df_dict, f"cannot find the requested 'model_str' in mysql: {model_str}"
        return model_str

    def get_data(self, id:int, model_str:str):
        """
        obtains {id} from the respective mysql database and builds a data model from that.
        the id should match the 'id'-column in the mysql schema.
        returns: data model, such as Ship, Shipcall, etc.
        e.g.,
            data = self.get_data(0,"shipcall")
            returns a Shipcall object
        """
        model_str = self.standardize_model_str(model_str)
        df = self.df_dict.get(model_str)
        data = self.df_loc_to_data_model(df, id, model_str)
        return data

    def get_all(self, model_str:str)->list:
        """
        given a model string (e.g., 'shipcall'), return a list of all
        data models of that type from the sql
        """
        model_str = self.standardize_model_str(model_str)
        all_ids = self.df_dict.get(model_str).index
        all_data = [
            self.get_data(_aid, model_str)
            for _aid in all_ids
        ]
        return all_data

    def df_loc_to_data_model(self, df, id, model_str, loc_type:str="loc"):
        """
        builds the data model registered for {model_str} from a single row of {df}.

        id: an index label when loc_type=="loc", otherwise a row position (used with .iloc).
        The model always receives the row's index label as its 'id', also for positional access.
        raises AssertionError when {df} is empty or {model_str} has no registered data model.
        """
        assert len(df)>0, f"empty dataframe"
        # get a pandas series from the dataframe
        if loc_type=="loc":
            series = df.loc[id]
            row_id = id
        else:
            series = df.iloc[id]
            # 'id' is only a position here; the record's real id is the index label at that position
            row_id = df.index[id]
            if isinstance(row_id, np.generic):
                row_id = row_id.item()  # hand a plain Python scalar to the data model
        # get the respective data model object
        data_model = self.str_to_model_dict.get(model_str,None)
        assert data_model is not None, f"could not find the requested model_str: {model_str}"
        # build 'data' and fill the data model object
        data = {**{'id':row_id}, **series.to_dict()}  # 'id' must be added manually, as .to_dict does not contain the index, which was set with .set_index
        data = data_model(**data)
        return data

    def get_times_for_participant_type(self, df_times, participant_type:int):
        """
        returns the 'times' data model of the single row in {df_times} whose 'participant_type' equals {participant_type}.
        raises AssertionError when there are multiple matches or none (empty dataframe).
        """
        filtered_series = df_times.loc[df_times["participant_type"]==participant_type]
        assert len(filtered_series)<=1, f"found multiple results"
        times = self.df_loc_to_data_model(filtered_series, id=0, model_str='times', loc_type="iloc")  # use iloc! to retrieve the first result
        return times

    def dataframe_to_data_model_list(self, df, model_str)->list:
        """converts every row of {df} (addressed by index label) into the data model registered for {model_str}"""
        model_str = self.standardize_model_str(model_str)
        all_ids = df.index
        all_data = [
            self.df_loc_to_data_model(df, _aid, model_str)
            for _aid in all_ids
        ]
        return all_data

    def get_participants(self, shipcall_id:int)->list:
        """
        given a {shipcall_id}, obtain the respective list of participants from 'shipcall_participant_map'.
        when there are no participants, return a blank list
        returns: participant_id_list, with one 'participant_id' value per matching map row
        """
        df = self.df_dict.get("shipcall_participant_map")
        # a boolean mask always selects a Series (possibly empty), also when exactly one row matches;
        # label-based .loc on a 'shipcall_id' index returns a bare scalar in that case.
        # .isin (rather than ==) never produces NA entries in the mask, even if 'shipcall_id' has NULLs.
        matches = df["shipcall_id"].isin([shipcall_id])
        participant_id_list = df.loc[matches, "participant_id"].to_list()
        return participant_id_list

    def get_times_of_shipcall(self, shipcall)->pd.DataFrame:
        """returns all rows of the 'times' DataFrame whose 'shipcall_id' equals {shipcall}.id"""
        df_times = self.df_dict.get('times')  # -> pd.DataFrame
        df_times = df_times.loc[df_times["shipcall_id"]==shipcall.id]
        return df_times

    def get_times_for_agency(self, non_null_column=None)->pd.DataFrame:
        """
        returns all 'times' rows whose 'participant_type' equals ParticipantType.AGENCY.value
        options:
            non_null_column:
                None or str. If provided, the 'non_null_column'-column of the dataframe will be filtered,
                so only entries with provided values are returned (filters all NaN and NaT entries)
        """
        # get all times
        df_times = self.df_dict.get('times')  # -> pd.DataFrame
        # filter out all NaN and NaT entries
        if non_null_column is not None:
            df_times = df_times.loc[~df_times[non_null_column].isnull()]  # NOT null filter
        # filter by the agency participant_type
        times_agency = df_times.loc[df_times["participant_type"]==ParticipantType.AGENCY.value]
        return times_agency

    def filter_df_by_key_value(self, df, key, value)->pd.DataFrame:
        """returns the rows of {df} whose {key}-column equals {value}"""
        return df.loc[df[key]==value]

    def get_unique_ship_counts(self, all_df_times:pd.DataFrame, query:str, rounding:str="min", maximum_threshold=3):
        """
        given a dataframe of all agency times, get all unique ship counts, their values (datetime) and the string tags. returns a tuple (values,unique,counts)

        query: name of a datetime column of {all_df_times}
        rounding: pandas frequency string passed to .dt.round (e.g., 'min'); None disables rounding
        maximum_threshold: currently unused; kept so existing callers remain valid
        """
        # get values and optional: rounding
        values = all_df_times.loc[:, query]
        if rounding is not None:
            values = values.dt.round(rounding)  # e.g., 'min'
        unique, counts = np.unique(values, return_counts=True)
        return (values, unique, counts)