# start konzoly:
# export PYSPARK_PYTHON=python3
# pyspark --master yarn --num-executors 2 --executor-memory 4G --conf spark.ui.port=10811

from pyspark.sql import functions as F

# vychozi data
Tr = spark.sql("select * from pascepet.transakce2 and country='cze'") \ 
    .cache()
# podle zadani mela byt podminka cze=True, ale kvuli bugu Hive se udela omezeni primo na vychozi sloupec country 
# tabulka transakce2 byla vytvorena v Hive - predpoklada se, ze ma omezeni na datum podle varianty B

# 1. Kolik unikátních klientů je v datech?
cl_cnt = Tr.select('clid').distinct().count()
print(cl_cnt)
# Výsledek: 89 822

# 2. Jaký podíl klientů (v %) má aspoň jednu transakci v sobotu časně ráno (s číslem hodiny od 5 do 7)?
cl_sobota_rano_cnt = Tr.filter(F.dayofweek(Tr['day'])==7) \
    .filter((Tr['hour']>=5) & (Tr['hour']<=7)) \
    .select('clid').distinct().count()
print(cl_sobota_rano_cnt / cl_cnt * 100)
# Výsledek: 6.88 %

# 3. Jaký podíl klientů (v %) nemá žádnou ze svých transakcí na oblíbené stanici?
cl_fav_cnt = Tr.filter(Tr['fav']==1).select('clid').distinct().count() # klienti s aspon jednou transakci na oblibene stanici 
cl_unfav_cnt = cl_cnt - cl_fav_cnt # zbytek klientu nema transakci na oblibene stanici
print(cl_unfav_cnt / cl_cnt * 100)
# Výsledek: 47.11 % 

# 4. Jaký je nejvyšší počet různých oblíbených stanic pro jednoho klienta?
# + Zjistěte pro klienta s nejvyšším počtem různých oblíbených stanic průměr GPS souřadnic pro tyto stanice (každou stanici počítejte jen jednou).
# zaznam s nejvyssim poctem unikatnich oblibenych stanic
Tr2 = Tr.filter(Tr['fav']==1) \
    .select('clid', 'posid').distinct() \
    .groupBy('clid').count() \
    .toDF('clid', 'pocet_fav') \
    .orderBy('pocet_fav', ascending=False) \
    .limit(5)
Tr2.show()
# Výsledek: nejvyšší počet je 18
clid_favmax = Tr2.limit(1).select('clid').rdd.collect()[0]['clid'] # 14366121 (uznalo se i pouziti 14356835, také má 18 obl. stanic)  
# transakce vybraneho klienta na oblibenych stanicich, kazdou stanici vzit jen jednou
Tr.filter((Tr['clid']==clid_favmax) & (Tr['fav']==1)) \
    .select('clid', 'posid', 'pos_gps_lat', 'pos_gps_lon').distinct() \
    .groupBy().agg({'pos_gps_lat': 'avg', 'pos_gps_lon': 'avg'}) \
    .show()
# Výsledek:
# +-----------------+------------------+
# | avg(pos_gps_lat)|  avg(pos_gps_lon)|
# +-----------------+------------------+
# |50.75171767340766|15.003921084933811|
# +-----------------+------------------+

# 5. Rozdělíme transakce podle vzdálenosti stanice od bydliště do pásem po 5 km. Vyloučíme záznamy s vyšší vzdáleností než 150 km nebo s neuvedenou vzdáleností.
# + Jaký počet transakcí byl proveden v jednotlivých pásmech a ve kterém pásmu to bylo nejvíc?
# + Liší se to pro stanice v západní části ČR (pos_gps_lon <= 14) a pro stanice ve východní části ČR (pos_gps_lon >= 17)?
TrZ = Tr.filter(Tr['pos_gps_lon'] <= 14.0) # pomocne DataFrames jen pro zapadni a vychodni cast CR
TrV = Tr.filter(Tr['pos_gps_lon'] >= 17.0)
# pasma po 5 km vzdalenosti, ponechany jen zaznamy se znamou vzdalenosti do 150 km
Tr2 = Tr.withColumn('dist5', F.floor(Tr['dist'] / 5.0)).select('dist', 'dist5').dropna().filter('dist <= 150.0')
TrZ2 = TrZ.withColumn('dist5', F.floor(TrZ['dist'] / 5.0)).select('dist', 'dist5').dropna().filter('dist <= 150.0')
TrV2 = TrV.withColumn('dist5', F.floor(TrV['dist'] / 5.0)).select('dist', 'dist5').dropna().filter('dist <= 150.0')
# cetnosti v pasmech
Tr2.groupBy('dist5').count() \
    .toDF('dist5', 'pocet')  \
    .orderBy('pocet', ascending=False) \
    .show()
# Výsledek:
# +-----+------+
# |dist5| pocet|
# +-----+------+
# |    0|268314|
# |    1|132000|
# |    2| 71173|
# |    3| 39693|
# |    4| 25751|
# |    5| 17933|
# ...
# jen západní část ČR
TrZ2.groupBy('dist5').count() \
    .toDF('dist5', 'pocet')  \
    .orderBy('pocet', ascending=False) \
    .show()
# Výsledek:
# +-----+-----+
# |dist5|pocet|
# +-----+-----+
# |    0|37296|
# |    1|14761|
# |    2| 7915|
# |    3| 5031|
# |    4| 3117|
# |    5| 1919|
# ...
# jen východní část ČR
TrV2.groupBy('dist5').count() \
    .toDF('dist5', 'pocet')  \
    .orderBy('pocet', ascending=False) \
    .show()
# Výsledek:
# +-----+-----+
# |dist5|pocet|
# +-----+-----+
# |    0|67193|
# |    1|31388|
# |    2|15476|
# |    3| 8464|
# |    4| 5590|
# |    5| 3197|
# ...

# 6. Kterých osm stanic má nejvyšší počty plateb a kterých osm nejvyšší počty unikátních klientů?
Tr.groupBy('posid').count() \
    .toDF('posid', 'pocet') \
    .orderBy('pocet', ascending=False) \
    .show(8)
# Výsledek:
# +-----+-----+
# |posid|pocet|
# +-----+-----+
# | 6971| 3143|
# | 4585| 3047|
# | 7293| 2575|
# | 5217| 2356|
# | 7329| 2298|
# | 4531| 2164|
# | 6140| 2036|
# | 5685| 2025|
# +-----+-----+
Tr.select('clid', 'posid').distinct() \
    .groupBy('posid').count() \
    .toDF('posid', 'pocet') \
    .orderBy('pocet', ascending=False) \
    .show(8)
# Výsledek:
# +-----+-----+
# |posid|pocet|
# +-----+-----+
# | 6971| 1162|
# | 4585| 1159|
# | 7293| 1046|
# | 7379|  951|
# | 7329|  924|
# | 5090|  916|
# | 7442|  855|
# | 4531|  855|
# +-----+-----+

# 7. Zjistěte průměrnou zaplacenou částku na stanicích Shell a na stanicích Eurooil (názvy jsou součástí textu ve sloupci name), porovnejte průměry mezi sebou.
# v datech se udelaji sloupce s priznakem pro Shell a pro Eurooil
Tr2 = Tr.withColumn('flag_shell', Tr['name'].contains('shell')) \
    .withColumn('flag_eurooil', Tr['name'].contains('eurooil'))
# lze pouzit i obecnejsi detekci souladu s regularnim vyrazem, ale zde je zbytecne slozite:
# Tr2 = Tr.withColumn('flag_shell', F.when(F.regexp_extract('name', r'(shell)', 1)=="", False).otherwise(True)) \
#    .withColumn('flag_eurooil', F.when(F.regexp_extract('name', r'(eurooil)', 1)=="", False).otherwise(True))
# vypocet prumerne ceny
Tr2.filter(Tr2['flag_shell']) \
    .groupBy().avg('amt') \
    .show()
Tr2.filter(Tr2['flag_eurooil']) \
    .groupBy().avg('amt') \
    .show()
# Výsledek:
# +-----------------+
# |         avg(amt)|
# +-----------------+
# |507.9691738517474|
# +-----------------+
# +-----------------+
# |         avg(amt)|
# +-----------------+
# |529.7077961167021|
# +-----------------+
# Závěr: průměrná útrata u Eurooilu je asi o 22 k4 vyšší než u Shellu.

# 8. Jaká jsou nejčastější slova v názvech stanic (každou stanici počítejte jen jednou)? Liší se to pro stanice v západní a ve východní části ČR (viz úlohu 5)? 
# DataFrame se upravi na nazvy unikatnich stanic a konvertuje do RDD
# Vyuzijeme pomocne DataFrames z ulohy 5
Nazvy = Tr.select('posid', 'name').distinct() \
    .select('name').rdd 
NazvyZ = TrZ.select('posid', 'name').distinct() \
    .select('name').rdd
NazvyV = TrV.select('posid', 'name').distinct() \
    .select('name').rdd
# klasicky WordCount
Nazvy.flatMap(lambda x: x[0].split(' ')) \
    .map(lambda x: (x, 1)) \
    .reduceByKey(lambda a, b: a+b) \
    .sortBy(lambda r: -r[1]) \
    .take(5)
# Výsledek:
# [('cs', 1026), ('benzina', 385), ('cerpaci', 302), ('oil', 284), ('stanice', 277)]
NazvyZ.flatMap(lambda x: x[0].split(' ')) \
    .map(lambda x: (x, 1)) \
    .reduceByKey(lambda a, b: a+b) \
    .sortBy(lambda r: -r[1]) \
    .take(5)
# Výsledek:
# [('cs', 195), ('benzina', 83), ('oil', 61), ('stanice', 39), ('cerpaci', 39)]
NazvyV.flatMap(lambda x: x[0].split(' ')) \
    .map(lambda x: (x, 1)) \
    .reduceByKey(lambda a, b: a+b) \
    .sortBy(lambda r: -r[1]) \
    .take(5)
# Výsledek:
# [('cs', 224), ('benzina', 71), ('cerpaci', 57), ('stanice', 51), ('s', 50)]

# 9. Spočítejte Jaccardův koeficient podobnosti (počet prvků průniku děleno počtem prvků sjednocení) mezi množinami klientů s aspoň jednou platbou na stanicích s id 8776 a 8777.
# vsichni klienti stanice jako list
TrA = Tr.filter('posid=8776') \
    .select('clid').distinct() \
    .rdd.map(lambda x: x[0]) \
    .collect()
TrB = Tr.filter('posid=8777') \
    .select('clid').distinct() \
    .rdd.map(lambda x: x[0]) \
    .collect()
TrA = set(TrA) # vyhozeni duplicitnich prvku
TrB = set(TrB)
# Jaccarduv koeficient
print(len(TrA.intersection(TrB)) / len(TrA.union(TrB)))
# Výsledek: 0.16
