# using data from
# https://www.gapminder.org/data/
# downloaded data file: net_users_num.csv, population_total.csv

# https://docs.python.org/3/library/csv.html

import csv
import matplotlib.pyplot as plt

class Util:
  """Helpers for reading .csv tables and decoding their textual values."""

  # static fields of the class
  # recognised one-letter unit suffixes (lower and upper case are equivalent)
  unitNames   = [ 'k', 'K', 'm', 'M', 'g', 'G', 't', 'T', 'p', 'P' ]
  # multiplier belonging to each unit suffix
  unitValues  = { 'k': 10**3,   'K': 10**3,
                  'm': 10**6,   'M': 10**6,
                  'g': 10**9,   'G': 10**9,
                  't': 10**12,  'T': 10**12,
                  'p': 10**15,  'P': 10**15  }

  @staticmethod
  def strValue( str ):
    """Decode a textual number, optionally ending in a unit letter, to an int.

    Examples: "1500" -> 1500, "12.5k" -> 12500, "3M" -> 3000000.

    str -- the text to decode, or None.  (The name shadows the builtin;
           it is kept so that existing keyword callers keep working.)

    Returns the decoded integer, or 0 when the text is None, empty,
    whitespace-only, carries an unknown unit, or is not a number.
    """
    specValue = 0  # returned whenever the text cannot be decoded

    if str is None: return specValue
    text = str.strip()  # remove leading/trailing whitespaces, tabs
    # nothing left to decode; this also protects the text[-1] access below
    if text == "": return specValue

    # no unit
    if text[-1].isdigit():
      # try int first so that large integers keep their exact value
      try:
        return int( text )
      except ValueError:
        pass
      # not a plain integer: accept a decimal number such as "12.5",
      # the same way the unit branch below accepts "12.5k"
      try:
        return round( float( text ) )
      except (ValueError, OverflowError):
        return specValue

    # unit specified
    unitChar = text[-1]
    # unit unknown
    if unitChar not in Util.unitNames:
      return specValue
    # unit known
    try:
      return round( float( text[:-1] ) * Util.unitValues[unitChar] )
    except (ValueError, OverflowError, KeyError):
      # ValueError: the part before the unit is not a number (or is NaN),
      # OverflowError: the value is infinite,
      # KeyError: unit listed in unitNames but missing in unitValues
      return specValue

  @staticmethod
  def read_fields_rows( filename, encoding=None ):
    """Read a .csv file and return its field names and data rows.

    filename -- path of the .csv file, as a string (it is concatenated
                into the progress message printed below)
    encoding -- text encoding passed to open(); None (the default) keeps
                the platform default encoding

    Returns (fieldNames, rows): fieldNames is the first line of the file
    (= column names), rows is the list of all subsequent lines, each one
    a list of strings.  An empty file yields ([], []).
    Errors raised by open() (e.g. a missing file) are not caught here.
    """
    # newline='' leaves newline handling to the csv module, so that
    # newlines embedded in quoted fields are read correctly
    with open(filename, 'r', newline='', encoding=encoding) as csvfile:
      # creating a csv reader object
      csvreader = csv.reader(csvfile)
      # field names are on the first line; the default value avoids
      # a StopIteration when the file is empty
      fieldNames = next(csvreader, [])
      # all remaining lines are data rows
      rows = list(csvreader)
    print( filename + "  read.")
    print( "Number of data rows : %d" % (len(rows)) )
    return fieldNames, rows

  @staticmethod
  def reduce_to_fields( fieldNames, rows, selFieldNames ):
    """Reduce csv data (field names and rows) to selected fields only.

    fieldNames    -- column names of the table
    rows          -- data rows, each a list of values
    selFieldNames -- names of the columns to keep

    Column 0 (the row name) is always retained, followed by every column
    whose name is in selFieldNames, in their original order.
    Empty rows (blank lines of a .csv file) are dropped; a row shorter
    than the header gets "" for each selected value it is missing.

    Returns (newFieldNames, newRows).
    """
    # positions of the selected columns, found once instead of per row
    selColumns = [ j for j, name in enumerate( fieldNames )
                   if name in selFieldNames ]

    # retain original column[0] header, then the selected headers
    newFieldNames = list( fieldNames[:1] ) + [ fieldNames[j] for j in selColumns ]

    newRows = []
    for row in rows:
      if not row: continue  # blank line, there is no row name to retain
      newRow = [ row[0] ]   # retain original row[0] (= row name)
      for j in selColumns:
        newRow.append( row[j] if j < len( row ) else "" )
      newRows.append( newRow )

    return newFieldNames, newRows


# -------------------------------------------------------------------
#    p r o c e s s
# -------------------------------------------------------------------

# extract rows corresponding to given country names
# convert row values from original text format to usual numbers (integers)
def extractRows( selectedNames, rows ):
  """Return the rows whose name (column 0) is listed in selectedNames.

  The row name is kept as it is; every other value of a kept row is
  decoded from text by Util.strValue.  The rows are returned in the
  order in which they appear in `rows`.
  """
  return [
    [ record[0] ] + [ Util.strValue( cell ) for cell in record[1:] ]
    for record in rows
    if record[0] in selectedNames
  ]

# -------------------------------------------------------------------
#    M  A  I  N
# -------------------------------------------------------------------
# NU .. Net Users, PT .. Population Total

# read data and reduce data to selected columns (=years)
filename = "net_users_num.csv"
fieldNamesNU, rowsNU = Util.read_fields_rows( filename )
filename = "population_total.csv"
fieldNamesPT, rowsPT = Util.read_fields_rows( filename )

# the selected columns are the years 2000 .. 2010
selFieldNames = [str(k) for k in range(2000, 2011)]

fieldNamesNU, rowsNU = Util.reduce_to_fields( fieldNamesNU, rowsNU, selFieldNames )
fieldNamesPT, rowsPT = Util.reduce_to_fields( fieldNamesPT, rowsPT, selFieldNames )

selectedNames = ['Andorra', 'Mongolia', 'Peru', 'Portugal', 'Czech Republic' ]
selectedRowsNU = extractRows( selectedNames, rowsNU )
selectedRowsPT = extractRows( selectedNames, rowsPT )

# create new data - ratio of    net users / population
# The two tables are matched by country name and by year name, not by
# position: the extracted rows come in file order, and the two files need
# not list the same countries or the same years.
rowPTByName = { row[0]: row for row in selectedRowsPT }
columnPT = { name: j for j, name in enumerate( fieldNamesPT ) if j > 0 }

selectedRowsNUPT = []
for rowNU in selectedRowsNU:
  rowPT = rowPTByName.get( rowNU[0] )
  if rowPT is None:
    print( "No population data for %s - skipped." % rowNU[0] )
    continue
  rowNUPT = [ rowNU[0] ]  # retain row name
  for year, users in zip( fieldNamesNU[1:], rowNU[1:] ):
    jPT = columnPT.get( year )
    # a year missing in the population table counts as population 0
    population = rowPT[jPT] if jPT is not None and jPT < len( rowPT ) else 0
    # avoid division by 0
    rowNUPT.append( users / population if population > 0 else 0 )
  selectedRowsNUPT.append( rowNUPT )

print()
print( selectedRowsNU )
print()
print( selectedRowsPT )
print()
print( selectedRowsNUPT )


# create x and y data series for each selected row
# to be plotted by matplotlib
# plot all series in one picture

# all data values to display are in columns 1,2, ...
x_years = [ Util.strValue( s ) for s in fieldNamesNU[1:] ]
y_values = [ row[1:] for row in selectedRowsNUPT ]

# create line plot for the values:
plt.title("web users in countries/years \n relative to population total in that country")
plt.xlabel("x-year")
plt.ylabel("y-users")
# separate data series for each row (=country);
# the label is taken from the row itself, because the rows are in file
# order, which may differ from the order of selectedNames
for row, series in zip( selectedRowsNUPT, y_values ):
  plt.plot( x_years, series, label=row[0], linewidth=3 )
plt.legend()
plt.show()
















