"""
"""
from abc import ABC, abstractmethod
import polars as pl
class AggregationStrategy(ABC):
    """Common interface for the aggregation strategies in this module.

    Each concrete strategy overrides :meth:`apply`; the ones defined here all
    return frames keyed by a ``cell`` column.
    """

    @abstractmethod
    def apply(self, data: pl.LazyFrame, target_cols: list[str]) -> pl.LazyFrame:
        """Aggregate the ``target_cols`` of ``data``.

        Args:
            data (pl.LazyFrame): frame holding the values to aggregate.
            target_cols (list[str]): names of the columns to aggregate.

        Returns:
            pl.LazyFrame: the aggregated frame.

        Raises:
            NotImplementedError: always; concrete strategies must override this.
        """
        raise NotImplementedError("Subclasses must implement this method")
class SplitEqually(AggregationStrategy):
    """Spread a group-level value evenly over the rows of its group.

    .. image:: ../../images/SplitEqually.svg
    """

    def __init__(self, agg_col: str):
        """
        Args:
            agg_col (str): column identifying the group to split over; usually a
                boundary such as a city, town or village.
        """
        self.agg_col = agg_col

    def apply(self, data: pl.LazyFrame, target_cols: list[str]) -> pl.LazyFrame:
        """Divide each target value equally among the rows sharing ``agg_col``.

        For every column in ``target_cols``, the first value found within each
        ``self.agg_col`` group is divided by the number of non-null values of
        that column in the group, and the quotient overwrites the column.

        Args:
            data (pl.LazyFrame): rows containing ``cell``, ``self.agg_col`` and
                every target column. Presumably one row per (cell, boundary)
                with the boundary's value repeated on each row; confirm
                against the caller.
            target_cols (list[str]): columns whose values are split.

        Returns:
            pl.LazyFrame: ``cell`` plus the split ``target_cols``; all other
            columns are dropped.
        """
        group = self.agg_col
        split_exprs = []
        for name in target_cols:
            # first value of the group / non-null count in the group
            share = pl.first(name).over(group) / pl.count(name).over(group)
            split_exprs.append(share.alias(name))  # overwrite the original column

        with_shares = data.with_columns(split_exprs)
        return with_shares.select(pl.col('cell'), pl.col(target_cols))
class Centroid(AggregationStrategy):
    """Carry values over unchanged, keeping only ``cell`` and the target columns.

    .. image:: ../../images/Centroid.svg
    """

    def apply(self, data: pl.LazyFrame, target_cols: list[str]) -> pl.LazyFrame:
        """Project ``data`` onto ``cell`` and ``target_cols`` without aggregating.

        No grouping or arithmetic happens here. Presumably each input row is
        already mapped to a single cell (e.g. a boundary's centroid cell);
        confirm against the caller.

        Args:
            data (pl.LazyFrame): rows containing ``cell`` and every target column.
            target_cols (list[str]): columns to keep.

        Returns:
            pl.LazyFrame: ``cell`` plus ``target_cols`` with values unchanged.
        """
        return data.select(pl.col('cell'), pl.col(target_cols))
class SumUp(AggregationStrategy):
    """Sum the target columns over all rows that share a ``cell``.

    .. image:: ../../images/SumUp.svg
    """

    def apply(self, df: pl.LazyFrame, target_cols: list[str]) -> pl.LazyFrame:
        """Sum ``target_cols`` per ``cell``.

        Args:
            df (pl.LazyFrame): rows containing ``cell`` and every column in
                ``target_cols``. A ``pl.DataFrame`` is also accepted, in which
                case a ``pl.DataFrame`` is returned. The parameter keeps the
                name ``df`` for compatibility with keyword callers.
            target_cols (list[str]): columns to sum; each is cast to Float64
                first.

        Returns:
            pl.LazyFrame: one row per ``cell`` with the summed ``target_cols``.
            Row order is not guaranteed (``group_by`` without
            ``maintain_order``).
        """
        return (
            df
            .group_by('cell')
            .agg(pl.col(target_cols).cast(pl.Float64).sum())
        )
class Mean(AggregationStrategy):
    """Average the target columns over all rows that share a ``cell``.

    .. image:: ../../images/Mean.svg
    """

    def apply(self, data: pl.LazyFrame, target_cols: list[str]) -> pl.LazyFrame:
        """Compute the per-``cell`` mean of ``target_cols``.

        Args:
            data (pl.LazyFrame): rows containing ``cell`` and every target column.
            target_cols (list[str]): columns to average; each is cast to Float64
                before averaging.

        Returns:
            pl.LazyFrame: one row per ``cell`` with the averaged ``target_cols``.
        """
        per_cell_mean = pl.col(target_cols).cast(pl.Float64).mean()
        grouped = data.group_by('cell')
        return grouped.agg(per_cell_mean)
class Count(AggregationStrategy):
    """Count rows per cell, optionally broken down by category columns.

    .. image:: ../../images/Count.svg
    """

    def __init__(self, return_percentage: bool = False):
        """
        Args:
            return_percentage (bool): when True and ``target_cols`` names
                category columns, report each category as a percentage of the
                cell's total (rounded to 3 decimals) instead of a raw count,
                and drop the ``total_count`` column.
        """
        self.return_percentage = return_percentage

    def apply(self, data: pl.LazyFrame, target_cols: list[str]) -> pl.LazyFrame:
        """Count rows per ``cell``.

        Args:
            data (pl.LazyFrame): rows containing ``cell`` and every column in
                ``target_cols``.
            target_cols (list[str]): category columns to break the counts down
                by. ``['hex_id']`` or an empty list yields only the per-cell
                row count.

        Returns:
            pl.LazyFrame: one row per cell.

            - Plain count: columns ``cell`` and ``total_count`` (Int64).
            - Otherwise the counts are pivoted to one column per distinct
              category value, plus ``total_count``.
            - With ``return_percentage``, the category columns hold
              percentages and ``total_count`` is dropped.
        """
        # An empty list previously fell through every branch and returned
        # None; with no categories to split by, the per-cell total is the
        # only meaningful count.
        if not target_cols or target_cols == ['hex_id']:
            # focus on the h3 index: just count rows per cell
            return self._total_count(data)

        pivoted_df = self._category_counts(data, target_cols)
        if self.return_percentage:
            return self._to_percentage(pivoted_df)
        return pivoted_df.lazy()  # dataframe -> lazyframe

    @staticmethod
    def _total_count(data):
        """Return the number of rows per cell as an Int64 ``total_count`` column."""
        return (
            data
            .group_by('cell')
            .agg(pl.len().alias('total_count').cast(pl.Int64))
            .lazy()
        )

    @staticmethod
    def _category_counts(data, target_cols):
        """Count rows per (cell, category) and pivot categories into columns (DataFrame)."""
        count_col = f'{"_".join(target_cols)}_count'
        counts_lf = (
            data
            .group_by(['cell', *target_cols])
            .agg(pl.len().alias(count_col).cast(pl.Int64))
            # nulls are filled with the literal string 'null' before pivoting
            .fill_null('null')
        )
        return (
            counts_lf
            .collect()  # pivot needs an eager DataFrame
            .pivot(
                values=count_col,
                index='cell',
                on=target_cols,
            )
            .with_columns(
                pl.sum_horizontal(pl.exclude('cell')).alias('total_count').cast(pl.Int64)
            )
        )

    @staticmethod
    def _to_percentage(pivoted_df):
        """Replace category counts with their percentage of ``total_count``, as a LazyFrame."""
        # NOTE: percentages are only meaningful when compared on the same
        # total_count basis.
        percentage_cols = [
            (pl.col(col) / pl.col('total_count') * 100).round(3)
            for col in pivoted_df.columns
            if col not in ('cell', 'total_count')
        ]
        return (
            pivoted_df
            .with_columns(percentage_cols)
            # total_count is dropped when percentages are returned
            .select(pl.exclude('total_count'))
            .lazy()  # dataframe -> lazyframe
        )