# Packages Required
library(neuralnet)  # ANN
library(onehot)     # One-hot encoding
GitHub Repository

   

Introduction

Artificial neural networks borrow from biology. Like their namesake, artificial neural networks model the neural connections present in our own brains. Nodes in a network are connected to inputs and outputs. Whether or not a node's output is activated depends on its inputs, the weights on those inputs, and the activation function. With enough nodes, a network can approximate any continuous function arbitrarily well (the universal approximation theorem). What an ANN learns during training are the weights on the inputs.

   

The Mushroom Data Set

The data set by Schlimmer (1987), retrieved from the UCI machine learning repository (Lichman, 2013), contains 8124 examples with 22 features representing 23 species of gilled mushrooms. Each example is labeled edible (e) or poisonous (p). The 22 features are:

  1. cap-shape: bell=b,conical=c,convex=x,flat=f, knobbed=k,sunken=s
  2. cap-surface: fibrous=f,grooves=g,scaly=y,smooth=s
  3. cap-color: brown=n,buff=b,cinnamon=c,gray=g,green=r, pink=p,purple=u,red=e,white=w,yellow=y
  4. bruises?: bruises=t,no=f
  5. odor: almond=a,anise=l,creosote=c,fishy=y,foul=f, musty=m,none=n,pungent=p,spicy=s
  6. gill-attachment: attached=a,descending=d,free=f,notched=n
  7. gill-spacing: close=c,crowded=w,distant=d
  8. gill-size: broad=b,narrow=n
  9. gill-color: black=k,brown=n,buff=b,chocolate=h,gray=g, green=r,orange=o,pink=p,purple=u,red=e, white=w,yellow=y
  10. stalk-shape: enlarging=e,tapering=t
  11. stalk-root: bulbous=b,club=c,cup=u,equal=e, rhizomorphs=z,rooted=r,missing=?
  12. stalk-surface-above-ring: fibrous=f,scaly=y,silky=k,smooth=s
  13. stalk-surface-below-ring: fibrous=f,scaly=y,silky=k,smooth=s
  14. stalk-color-above-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o, pink=p,red=e,white=w,yellow=y
  15. stalk-color-below-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o, pink=p,red=e,white=w,yellow=y
  16. veil-type: partial=p,universal=u
  17. veil-color: brown=n,orange=o,white=w,yellow=y
  18. ring-number: none=n,one=o,two=t
  19. ring-type: cobwebby=c,evanescent=e,flaring=f,large=l, none=n,pendant=p,sheathing=s,zone=z
  20. spore-print-color: black=k,brown=n,buff=b,chocolate=h,green=r, orange=o,purple=u,white=w,yellow=y
  21. population: abundant=a,clustered=c,numerous=n, scattered=s,several=v,solitary=y
  22. habitat: grasses=g,leaves=l,meadows=m,paths=p, urban=u,waste=w,woods=d

   

Exploratory Data Analysis

The purpose of exploring the data first is to get familiar with it and to see if anything is of interest.

# Read in the data. stringsAsFactors = TRUE matters on R >= 4.0, where
# read.csv() otherwise keeps text columns as character. The steps that follow
# (summary() counts, levels<-, as.numeric() level codes) expect factors.
shroom <- read.csv("agaricus-lepiota.csv", stringsAsFactors = TRUE)
str(shroom)
## 'data.frame':    8124 obs. of  23 variables:
##  $ class                   : Factor w/ 2 levels "e","p": 2 1 1 2 1 1 1 1 2 1 ...
##  $ cap.shape               : Factor w/ 6 levels "b","c","f","k",..: 6 6 1 6 6 6 1 1 6 1 ...
##  $ cap.surface             : Factor w/ 4 levels "f","g","s","y": 3 3 3 4 3 4 3 4 4 3 ...
##  $ cap.color               : Factor w/ 10 levels "b","c","e","g",..: 5 10 9 9 4 10 9 9 9 10 ...
##  $ bruises                 : Factor w/ 2 levels "f","t": 2 2 2 2 1 2 2 2 2 2 ...
##  $ odor                    : Factor w/ 9 levels "a","c","f","l",..: 7 1 4 7 6 1 1 4 7 1 ...
##  $ gill.attachment         : Factor w/ 2 levels "a","f": 2 2 2 2 2 2 2 2 2 2 ...
##  $ gill.spacing            : Factor w/ 2 levels "c","w": 1 1 1 1 2 1 1 1 1 1 ...
##  $ gill.size               : Factor w/ 2 levels "b","n": 2 1 1 2 1 1 1 1 2 1 ...
##  $ gill.color              : Factor w/ 12 levels "b","e","g","h",..: 5 5 6 6 5 6 3 6 8 3 ...
##  $ stalk.shape             : Factor w/ 2 levels "e","t": 1 1 1 1 2 1 1 1 1 1 ...
##  $ stalk.root              : Factor w/ 5 levels "?","b","c","e",..: 4 3 3 4 4 3 3 3 4 3 ...
##  $ stalk.surface.above.ring: Factor w/ 4 levels "f","k","s","y": 3 3 3 3 3 3 3 3 3 3 ...
##  $ stalk.surface.below.ring: Factor w/ 4 levels "f","k","s","y": 3 3 3 3 3 3 3 3 3 3 ...
##  $ stalk.color.above.ring  : Factor w/ 9 levels "b","c","e","g",..: 8 8 8 8 8 8 8 8 8 8 ...
##  $ stalk.color.below.ring  : Factor w/ 9 levels "b","c","e","g",..: 8 8 8 8 8 8 8 8 8 8 ...
##  $ veil.type               : Factor w/ 1 level "p": 1 1 1 1 1 1 1 1 1 1 ...
##  $ veil.color              : Factor w/ 4 levels "n","o","w","y": 3 3 3 3 3 3 3 3 3 3 ...
##  $ ring.number             : Factor w/ 3 levels "n","o","t": 2 2 2 2 2 2 2 2 2 2 ...
##  $ ring.type               : Factor w/ 5 levels "e","f","l","n",..: 5 5 5 5 1 5 5 5 5 5 ...
##  $ spore.print.color       : Factor w/ 9 levels "b","h","k","n",..: 3 4 4 3 4 3 3 4 3 3 ...
##  $ population              : Factor w/ 6 levels "a","c","n","s",..: 4 3 3 4 1 3 3 4 5 4 ...
##  $ habitat                 : Factor w/ 7 levels "d","g","l","m",..: 6 2 4 6 2 2 4 4 2 4 ...

 

lapply(X = shroom, FUN = summary) # Count how often each level occurs, per column
## $class
##    e    p 
## 4208 3916 
## 
## $cap.shape
##    b    c    f    k    s    x 
##  452    4 3152  828   32 3656 
## 
## $cap.surface
##    f    g    s    y 
## 2320    4 2556 3244 
## 
## $cap.color
##    b    c    e    g    n    p    r    u    w    y 
##  168   44 1500 1840 2284  144   16   16 1040 1072 
## 
## $bruises
##    f    t 
## 4748 3376 
## 
## $odor
##    a    c    f    l    m    n    p    s    y 
##  400  192 2160  400   36 3528  256  576  576 
## 
## $gill.attachment
##    a    f 
##  210 7914 
## 
## $gill.spacing
##    c    w 
## 6812 1312 
## 
## $gill.size
##    b    n 
## 5612 2512 
## 
## $gill.color
##    b    e    g    h    k    n    o    p    r    u    w    y 
## 1728   96  752  732  408 1048   64 1492   24  492 1202   86 
## 
## $stalk.shape
##    e    t 
## 3516 4608 
## 
## $stalk.root
##    ?    b    c    e    r 
## 2480 3776  556 1120  192 
## 
## $stalk.surface.above.ring
##    f    k    s    y 
##  552 2372 5176   24 
## 
## $stalk.surface.below.ring
##    f    k    s    y 
##  600 2304 4936  284 
## 
## $stalk.color.above.ring
##    b    c    e    g    n    o    p    w    y 
##  432   36   96  576  448  192 1872 4464    8 
## 
## $stalk.color.below.ring
##    b    c    e    g    n    o    p    w    y 
##  432   36   96  576  512  192 1872 4384   24 
## 
## $veil.type
##    p 
## 8124 
## 
## $veil.color
##    n    o    w    y 
##   96   96 7924    8 
## 
## $ring.number
##    n    o    t 
##   36 7488  600 
## 
## $ring.type
##    e    f    l    n    p 
## 2776   48 1296   36 3968 
## 
## $spore.print.color
##    b    h    k    n    o    r    u    w    y 
##   48 1632 1872 1968   48   72   48 2388   48 
## 
## $population
##    a    c    n    s    v    y 
##  384  340  400 1248 4040 1712 
## 
## $habitat
##    d    g    l    m    p    u    w 
## 3148 2148  832  292 1144  368  192

 

Immediately noticeable is that veil type has only one level and stalk root is missing ("?") in 2480 examples. Other than that, the data looks good.

   

Data Preprocessing

Drop Veil Type

The veil type is a constant, so drop it from the data set.

shroom$veil.type <- NULL # Remove the constant veil.type column

 

Handle Missing Values

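In this case, stalk root contains many missing values, coded as "?". There are many different ways to handle this type of data; in this example the missing values are kept as their own factor level, renamed "u" for unknown. (In the UCI attribute list "u" also stands for "cup", but no cup values occur in this data, so the codes do not clash.)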

table(shroom[["stalk.root"]]) # Inspect the level counts, including "?"
## 
##    ?    b    c    e    r 
## 2480 3776  556 1120  192

 

# Make sure stalk.root is a factor. This is a no-op when read.csv() already
# produced one, since every level is in use. Then rename only the "?" level
# to "u". Renaming by name, rather than overwriting the whole levels vector,
# does not depend on "?" happening to be the first level.
shroom$stalk.root <- factor(shroom$stalk.root)
levels(shroom$stalk.root)[levels(shroom$stalk.root) == "?"] <- "u" # Convert
table(shroom$stalk.root) # Inspect
## 
##    u    b    c    e    r 
## 2480 3776  556 1120  192

   

The Artificial Neural Network

Create Numeric Data

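The ANN algorithm requires all data to be numeric, including the outputs (labels). Here each factor is replaced by its integer level code (e.g. class e = 1, p = 2). Note that this imposes an arbitrary order on the categories; one-hot encoding (the onehot package loaded at the top) would avoid that.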

# Convert every column (the class label included) to numeric. Each factor
# becomes its integer level code, e.g. class "e" -> 1 and "p" -> 2.
# Character columns (what read.csv() returns on R >= 4.0 unless
# stringsAsFactors = TRUE) are turned into factors first, because
# as.numeric() on text would give NA.
shroom_numeric <- as.data.frame(lapply(shroom, function(col) {
  if (is.character(col)) {
    col <- factor(col)
  }
  as.numeric(col)
}))
str(shroom_numeric)
## 'data.frame':    8124 obs. of  22 variables:
##  $ class                   : num  2 1 1 2 1 1 1 1 2 1 ...
##  $ cap.shape               : num  6 6 1 6 6 6 1 1 6 1 ...
##  $ cap.surface             : num  3 3 3 4 3 4 3 4 4 3 ...
##  $ cap.color               : num  5 10 9 9 4 10 9 9 9 10 ...
##  $ bruises                 : num  2 2 2 2 1 2 2 2 2 2 ...
##  $ odor                    : num  7 1 4 7 6 1 1 4 7 1 ...
##  $ gill.attachment         : num  2 2 2 2 2 2 2 2 2 2 ...
##  $ gill.spacing            : num  1 1 1 1 2 1 1 1 1 1 ...
##  $ gill.size               : num  2 1 1 2 1 1 1 1 2 1 ...
##  $ gill.color              : num  5 5 6 6 5 6 3 6 8 3 ...
##  $ stalk.shape             : num  1 1 1 1 2 1 1 1 1 1 ...
##  $ stalk.root              : num  4 3 3 4 4 3 3 3 4 3 ...
##  $ stalk.surface.above.ring: num  3 3 3 3 3 3 3 3 3 3 ...
##  $ stalk.surface.below.ring: num  3 3 3 3 3 3 3 3 3 3 ...
##  $ stalk.color.above.ring  : num  8 8 8 8 8 8 8 8 8 8 ...
##  $ stalk.color.below.ring  : num  8 8 8 8 8 8 8 8 8 8 ...
##  $ veil.color              : num  3 3 3 3 3 3 3 3 3 3 ...
##  $ ring.number             : num  2 2 2 2 2 2 2 2 2 2 ...
##  $ ring.type               : num  5 5 5 5 1 5 5 5 5 5 ...
##  $ spore.print.color       : num  3 4 4 3 4 3 3 4 3 3 ...
##  $ population              : num  4 3 3 4 1 3 3 4 5 4 ...
##  $ habitat                 : num  6 2 4 6 2 2 4 4 2 4 ...

 

Split the Data Between Testing and Training

# Reproducible 70:30 split: fix the RNG seed, draw 70% of the row indices
# (without replacement) for training, and keep the remaining rows for testing.
set.seed(77)
n_rows <- nrow(shroom_numeric)
idx <- sample(n_rows, round(n_rows * 0.7))
shroom_train <- shroom_numeric[idx, ]
shroom_test <- shroom_numeric[-idx, ]

 

Train the Model

# Train a network with neuralnet's default settings (no `hidden` argument)
# that predicts `class` from every other column of shroom_train. The formula
# is generated from the column names; it is the same
# class ~ cap.shape + ... + habitat formula, with the terms in column order.
shroom_model <- neuralnet(
  reformulate(setdiff(names(shroom_train), "class"), response = "class"),
  data = shroom_train
)



plot(shroom_model, rep = "best")
Figure 1: Simple Neural Network Architecture

Figure 1: Simple Neural Network Architecture

 

Test the Model

# Run the trained network on the test predictors. Columns are selected by
# name, not by position (2:ncol), so the call does not silently depend on
# `class` being the first column of shroom_test.
test_features <- shroom_test[, setdiff(names(shroom_test), "class")]
shroom_pred <- compute(shroom_model, test_features)

 

Evaluating Performance

The network's outputs are continuous values rather than exact class codes, so each output is rounded to the nearest class code (1 = edible, 2 = poisonous) to get the binary classification.

tail(shroom_pred$net.result, n = 6L) # Continuous values near 1 or 2, not exact class codes
##             [,1]
## 8113 1.080281568
## 8114 1.986631160
## 8117 1.986949223
## 8120 1.080284446
## 8121 1.080281548
## 8123 1.986922528

 

# Return the proportion (0 to 1) of predictions that match the true labels.
# The network outputs continuous values (e.g. 1.08 or 1.99), so predictions
# are rounded to the nearest integer class code before comparing.
#
# prediction: numeric vector or one-column matrix of raw network outputs
# real:       numeric vector of true class codes, same length as prediction
# Returns a single number in [0, 1], or NaN when there are no predictions.
# Missing (NA) predictions count as incorrect, not correct.
get.accuracy <- function(prediction, real) {
  # Guard against silent recycling when the two inputs differ in length
  stopifnot(length(prediction) == length(real))
  correct <- round(as.vector(prediction)) == real
  sum(correct, na.rm = TRUE) / length(correct)
}

 

get.accuracy(prediction = shroom_pred$net.result, real = shroom_test$class)
## [1] 0.9417316373

 

Tweaking Hyperparameters

There are many neural network hyperparameters or ‘knobs’ to try tweaking to get better results.

 

Adjusting the Network by Adding Hidden Nodes

# Train one network per element of `hidden_nodes` and report its test accuracy.
#
# hidden_nodes: vector of settings; each element is passed as the `hidden`
#               argument of neuralnet()
# train:        training data frame with `class` and the predictor columns
# test:         test predictors only (no `class` column), passed to compute()
# test_cl:      true class codes for the rows of `test`
# Returns a data frame with one row per setting and columns Nodes and
# Accuracy, where Accuracy is the value returned by get.accuracy().
add.hidden.nodes <- function(hidden_nodes, train, test, test_cl) {
  # The formula is the same for every run, so build it once outside the loop
  shroom_formula <- class ~ cap.shape +
                            cap.surface +
                            cap.color +
                            bruises +
                            odor +
                            gill.attachment +
                            gill.spacing +
                            gill.size +
                            gill.color +
                            stalk.shape +
                            stalk.root +
                            stalk.surface.above.ring +
                            stalk.surface.below.ring +
                            stalk.color.above.ring +
                            stalk.color.below.ring +
                            veil.color +
                            ring.number +
                            ring.type +
                            spore.print.color +
                            population +
                            habitat

  # vapply() yields exactly one accuracy per setting, in order, instead of
  # growing result vectors with c() on every iteration
  accuracy <- vapply(hidden_nodes, function(nodes) {
    model <- neuralnet(shroom_formula, data = train, hidden = nodes)
    prediction <- compute(model, test)
    get.accuracy(prediction$net.result, test_cl)
  }, numeric(1))

  data.frame(Nodes = hidden_nodes, Accuracy = accuracy)
}

 

test_features <- shroom_test[, setdiff(names(shroom_test), "class")]
add.hidden.nodes(hidden_nodes = c(0, 2, 4, 6),
                 train = shroom_train,
                 test = test_features,
                 test_cl = shroom_test$class)
##   Nodes     Accuracy
## 1     0 0.9437833402
## 2     2 0.9954862536
## 3     4 0.9934345507
## 4     6 1.0000000000

 

Adjusting the Network by Adding a Hidden Layer

# Two hidden layers: 4 nodes in the first layer, 2 in the second.
# The response is `class`; every other column of shroom_train is a predictor
# (the formula is generated from the column names rather than typed out).
layered_formula <- reformulate(setdiff(names(shroom_train), "class"),
                               response = "class")
shroom_model_layered <- neuralnet(layered_formula,
                                  data = shroom_train,
                                  hidden = c(4, 2))



plot(shroom_model_layered, rep = "best")