201 lines
8.6 KiB
Python
201 lines
8.6 KiB
Python
import numpy as np
|
|
import pandas as pd
|
|
import datetime
|
|
from BreCal.schemas.model import Shipcall, Ship, Participant, Berth, User, Times
|
|
from brecal_utils.database.enums import ParticipantType
|
|
|
|
class SQLHandler():
    """
    Reads SQL tables from ``sql_connection`` into pandas DataFrames and converts
    individual rows into BreCal data-model objects (Shipcall, Ship, Participant,
    Berth, User, Times).

    When constructed with ``read_all=True``, every table of the connected MySQL
    schema is loaded into memory at once (see :meth:`read_all`).

    # #TODO_initialization: shipcall_tug_map, user_role_map & role_securable_map might be mapped to the respective dataframes
    """

    def __init__(self, sql_connection, read_all=False):
        """
        :param sql_connection: an open MySQL connection (must support ``.cursor(buffered=True)``
                               and be accepted by ``pandas.read_sql``).
        :param read_all: when True, eagerly read every table into ``self.df_dict``.
        """
        self.sql_connection = sql_connection
        self.all_schemas = self.get_all_schemas_from_mysql()
        self.build_str_to_model_dict()

        if read_all:
            self.read_all(self.all_schemas)

    def get_all_schemas_from_mysql(self):
        """Return a list of all table names available on the connected MySQL schema."""
        with self.sql_connection.cursor(buffered=True) as cursor:
            cursor.execute("SHOW TABLES")
            schema = cursor.fetchall()
        # each fetched row is a 1-tuple such as ('shipcall',) -> unwrap to plain strings
        all_schemas = [schem[0] for schem in schema]
        return all_schemas

    def build_str_to_model_dict(self):
        """
        creates a simple dictionary, which maps a string to a data object
        e.g.,
        'ship'->BreCal.schemas.model.Ship object
        """
        self.str_to_model_dict = {
            "shipcall": Shipcall, "ship": Ship, "participant": Participant,
            "berth": Berth, "user": User, "times": Times
        }
        return

    def read_mysql_table_to_df(self, table_name: str):
        """determine a {table_name}, which will be read from a mysql server. returns a pandas DataFrame with the respective data"""
        # NOTE(review): table_name is interpolated into the SQL string; callers are expected
        # to pass names obtained from SHOW TABLES (see get_all_schemas_from_mysql), not
        # untrusted input.
        df = pd.read_sql(sql=f"SELECT * FROM {table_name}", con=self.sql_connection)
        return df

    def mysql_to_df(self, query):
        """provide an arbitrary sql query that should be read from a mysql server {sql_connection}. returns a pandas DataFrame with the obtained data"""
        df = pd.read_sql(query, self.sql_connection).convert_dtypes()
        df = df.set_index('id', inplace=False)  # avoid inplace updates, so the raw sql remains unchanged
        return df

    def read_all(self, all_schemas):
        """Read every table named in {all_schemas} into ``self.df_dict`` and derive the shipcall participant lists."""
        # create a dictionary, which maps every mysql schema to pandas DataFrames
        self.df_dict = self.build_full_mysql_df_dict(all_schemas)

        # update the 'participants' column in 'shipcall'
        self.initialize_shipcall_participant_list()
        return

    def build_full_mysql_df_dict(self, all_schemas):
        """given a list of strings {all_schemas}, every schema will be read as individual pandas DataFrames to a dictionary with the respective keys. returns: dictionary {schema_name:pd.DataFrame}"""
        mysql_df_dict = {}
        for schem in all_schemas:
            query = f"SELECT * FROM {schem}"
            mysql_df_dict[schem] = self.mysql_to_df(query)
        return mysql_df_dict

    def initialize_shipcall_participant_list(self):
        """
        iteratively applies the .get_participants method to each shipcall.
        the function updates the 'participants' column.
        """
        # 1.) get all shipcalls
        df = self.df_dict.get('shipcall')

        # 2.) look up the participant ids for every shipcall id (= the DataFrame index,
        #     which was set from the 'id' column in mysql_to_df). get_participants returns
        #     a list of ids, or a blank list when the shipcall has no participants.
        #     A plain index comprehension avoids a row-wise apply(axis=1) whose only
        #     purpose was reading the index.
        df['participants'] = [self.get_participants(shipcall_id) for shipcall_id in df.index]
        return

    def standardize_model_str(self, model_str: str) -> str:
        """check if the 'model_str' is valid and apply lowercasing to the string"""
        model_str = model_str.lower()
        assert model_str in list(self.df_dict.keys()), f"cannot find the requested 'model_str' in mysql: {model_str}"
        return model_str

    def get_data(self, id: int, model_str: str):
        """
        obtains {id} from the respective mysql database and builds a data model from that.
        the id should match the 'id'-column in the mysql schema.
        returns: data model, such as Ship, Shipcall, etc.

        e.g.,
        data = self.get_data(0,"shipcall")
        returns a Shipcall object
        """
        model_str = self.standardize_model_str(model_str)

        df = self.df_dict.get(model_str)
        data = self.df_loc_to_data_model(df, id, model_str)
        return data

    def get_all(self, model_str: str) -> list:
        """
        given a model string (e.g., 'shipcall'), return a list of all
        data models of that type from the sql
        """
        model_str = self.standardize_model_str(model_str)
        all_ids = self.df_dict.get(model_str).index

        all_data = [
            self.get_data(_aid, model_str)
            for _aid in all_ids
        ]
        return all_data

    def df_loc_to_data_model(self, df, id, model_str, loc_type: str = "loc"):
        """
        Build a single data-model object from one row of {df}.

        :param df: source DataFrame whose index holds the model ids.
        :param id: row label (loc_type="loc") or positional index (loc_type="iloc").
        :param model_str: key into self.str_to_model_dict selecting the model class.
        :param loc_type: "loc" for label-based lookup, anything else for positional.
        :raises AssertionError: if {df} is empty or {model_str} is unknown.
        """
        assert len(df) > 0, f"empty dataframe"

        # get a pandas series from the dataframe
        series = df.loc[id] if loc_type == "loc" else df.iloc[id]

        # get the respective data model object
        data_model = self.str_to_model_dict.get(model_str, None)
        assert data_model is not None, f"could not find the requested model_str: {model_str}"

        # build 'data' and fill the data model object
        data = {**{'id': id}, **series.to_dict()}  # 'id' must be added manually, as .to_dict does not contain the index, which was set with .set_index
        data = data_model(**data)
        return data

    def get_times_for_participant_type(self, df_times, participant_type: int):
        """
        Filter {df_times} down to the single row matching {participant_type} and
        return it as a Times data model.

        :raises AssertionError: if more than one row matches, or if no row matches
                                (the empty-frame assert in df_loc_to_data_model fires).
        """
        filtered_series = df_times.loc[df_times["participant_type"] == participant_type]
        assert len(filtered_series) <= 1, f"found multiple results"
        times = self.df_loc_to_data_model(filtered_series, id=0, model_str='times', loc_type="iloc")  # use iloc! to retrieve the first result
        return times

    def dataframe_to_data_model_list(self, df, model_str) -> list:
        """Convert every row of {df} into a data model of type {model_str}. Returns a list."""
        model_str = self.standardize_model_str(model_str)

        all_ids = df.index
        all_data = [
            self.df_loc_to_data_model(df, _aid, model_str)
            for _aid in all_ids
        ]
        return all_data

    def get_participants(self, shipcall_id: int) -> list:
        """
        given a {shipcall_id}, obtain the respective list of participants.
        when there are no participants, return a blank list

        returns: participant_id_list, where every element is an int
        """
        df = self.df_dict.get("shipcall_participant_map")
        df = df.set_index('shipcall_id', inplace=False)

        # the 'if' call is needed to ensure, that no Exception is raised, when the shipcall_id is not present in the df
        # NOTE: indexing with the label *list* [shipcall_id] guarantees a Series result even
        # when only a single row matches; plain df.loc[shipcall_id, ...] would collapse a
        # single match to a scalar, which has no .to_list() and crashed this method.
        if shipcall_id in df.index:
            participant_id_list = df.loc[[shipcall_id], "participant_id"].to_list()
        else:
            participant_id_list = []
        return participant_id_list

    def get_times_of_shipcall(self, shipcall) -> pd.DataFrame:
        """Return all 'times' rows belonging to {shipcall} (matched via shipcall.id)."""
        df_times = self.df_dict.get('times')  # -> pd.DataFrame
        df_times = df_times.loc[df_times["shipcall_id"] == shipcall.id]
        return df_times

    def get_times_for_agency(self, non_null_column=None) -> pd.DataFrame:
        """
        options:
        non_null_column:
            None or str. If provided, the 'non_null_column'-column of the dataframe will be filtered,
            so only entries with provided values are returned (filters all NaN and NaT entries)
        """
        # get all times
        df_times = self.df_dict.get('times')  # -> pd.DataFrame

        # filter out all NaN and NaT entries
        if non_null_column is not None:
            df_times = df_times.loc[~df_times[non_null_column].isnull()]  # NOT null filter

        # filter by the agency participant_type
        times_agency = df_times.loc[df_times["participant_type"] == ParticipantType.AGENCY.value]
        return times_agency

    def filter_df_by_key_value(self, df, key, value) -> pd.DataFrame:
        """Return the rows of {df} whose column {key} equals {value}."""
        return df.loc[df[key] == value]

    def get_unique_ship_counts(self, all_df_times: pd.DataFrame, query: str, rounding: str = "min", maximum_threshold=3):
        """given a dataframe of all agency times, get all unique ship counts, their values (datetime) and the string tags. returns a tuple (values,unique,counts)"""
        # NOTE(review): {maximum_threshold} is kept for interface compatibility but is
        # currently unused — a threshold check was computed here into a dead local and
        # never returned. TODO: decide whether a violation flag should be part of the API.
        # get values and optional: rounding
        values = all_df_times.loc[:, query]
        if rounding is not None:
            values = values.dt.round(rounding)  # e.g., 'min'

        unique, counts = np.unique(values, return_counts=True)
        return (values, unique, counts)
|