
# using data from
# https://www.gapminder.org/data/
# downloaded data file: net_users_num.csv

# https://docs.python.org/3/library/csv.html

import csv
import matplotlib.pyplot as plt

class Util:
  """Helpers for reading .csv files and decoding numeric text values."""

  # static fields of the class: recognised unit suffixes and their multipliers
  unitNames   = [ 'k', 'K', 'm', 'M', 'g', 'G', 't', 'T', 'p', 'P' ]
  unitValues  = { 'k': 10**3,   'K': 10**3,
                  'm': 10**6,   'M': 10**6,
                  'g': 10**9,   'G': 10**9,
                  't': 10**12,  'T': 10**12,
                  'p': 10**15,  'P': 10**15  }

  @staticmethod
  def strValue( str ):
    """Decode a numeric text value such as '480', '1.5' or '2.3k' into an int.

    An optional trailing unit character (k/K, m/M, g/G, t/T, p/P) multiplies
    the number by 10**3, 10**6, 10**9, 10**12 or 10**15 respectively.
    Non-integral results are rounded to the nearest int.

    Args:
      str: the text to decode; None is accepted. (The name shadows the
        builtin ``str``; it is kept so keyword callers keep working.)

    Returns:
      The decoded int, or 0 when the text is None, empty or whitespace-only,
      ends with an unknown unit character, or is not a valid number.
    """
    specValue = 0  # returned whenever the text cannot be decoded

    if str is None:
      return specValue
    text = str.strip()  # remove leading/trailing whitespace, tabs
    if not text:
      # empty or whitespace-only: nothing to decode (and text[-1] would fail)
      return specValue

    lastChar = text[-1]
    if lastChar.isdigit():
      # no unit
      number, multiplier = text, 1
    elif lastChar in Util.unitValues:
      # unit known
      number, multiplier = text[:-1], Util.unitValues[lastChar]
    else:
      # unit unknown
      return specValue

    # exact integer arithmetic when possible (no float precision loss) ...
    try:
      return int( number ) * multiplier
    except ValueError:
      pass
    # ... otherwise accept decimals / exponents such as '1.5' or '2.3e2'
    try:
      return round( float( number ) * multiplier )
    except (ValueError, OverflowError):
      # not a number at all, or NaN / infinity, which round() cannot convert
      return specValue

  @staticmethod
  def read_fields_rows( filename ):
    """Read a .csv file and return its field names and subsequent rows.

    Args:
      filename: path of the .csv file to read.

    Returns:
      A tuple (fieldNames, rows). fieldNames is the list of column names taken
      from the first line ([] for an empty file). rows is a list holding one
      list of strings per following line.
    """
    # newline='' is required by the csv module so quoted fields that contain
    # line breaks are parsed correctly; an explicit encoding makes decoding
    # platform-independent ('utf-8-sig' also tolerates a leading BOM)
    with open( filename, 'r', newline='', encoding='utf-8-sig' ) as csvfile:
      csvreader = csv.reader( csvfile )
      # field names are on the first line; an empty file has no header
      fieldNames = next( csvreader, [] )
      rows = list( csvreader )
    print( filename + "  read." )
    print( "Number of data rows : %d" % (len(rows)) )
    return fieldNames, rows

# -------------------------------------------------------------------
#    p r o c e s s
# -------------------------------------------------------------------

def extractRows( selectedNames, rows ):
  """Extract the rows whose first cell is one of selectedNames.

  The remaining cells are converted from their original text format to
  numbers with Util.strValue.

  Args:
    selectedNames: row names (e.g. country names) to keep.
    rows: rows as lists of strings whose first cell is the row name.

  Returns:
    A list of [rowName, v1, v2, ...] lists in the order the rows appear in
    *rows* (not the order of selectedNames). Each value is an int; Util.strValue
    yields 0 for empty or undecodable cells.
  """
  wanted = set( selectedNames )  # built once: O(1) membership test per row
  newRows = []
  for row in rows:
    # csv.reader yields [] for blank lines; such rows have no name to match
    if not row or row[0] not in wanted:
      continue
    newRowValues = [ Util.strValue( s ) for s in row[1:] ]
    newRows.append( [row[0]] + newRowValues )
  return newRows

# -------------------------------------------------------------------
#    M  A  I  N
# -------------------------------------------------------------------

filename = "net_users_num.csv"
fieldNames, rows = Util.read_fields_rows( filename )

print( 'Field names are:\n' + ', '.join( fieldNames ) )
# blank csv lines come back as empty rows and carry no name
rowNames = [ row[0] for row in rows if row ]
print( 'Row names are:\n' + ', '.join( rowNames ) )

selectedNames = ['Andorra', 'Mongolia', 'Peru', 'Portugal', 'Czech Republic' ]
selectedRows = extractRows( selectedNames, rows )
print( selectedRows )

# create x and y data series for each selected row
# to be plotted by matplotlib
# plot all series in one picture

# all data values to display are in columns 1,2, ...
x_years = [ Util.strValue( s ) for s in fieldNames[1:] ]
y_values = [ row[1:] for row in selectedRows ]

# create scatter plot for the values:
plt.title("web users in countries/years")
plt.xlabel("x-year")
plt.ylabel("y-users")
# separate data series for each row (=country). Each series is labelled with
# its own row name: selectedRows follows the file's row order and omits names
# not found in the file, so it does not line up index-by-index with
# selectedNames.
for row, values in zip( selectedRows, y_values ):
  plt.scatter( x_years, values, label=row[0] )
  # use plt.plot(x_years, values, label=row[0]) instead for connected lines
plt.legend()
plt.show()
















