
pyomyo for classification

PerlinWarp edited this page Oct 12, 2021 · 2 revisions

Finger Classification with simple_classifier

pyomyo includes a simple_classifier script for live labelling and prediction of incoming data.
Running python simple_classifier.py launches a window that listens for incoming data and keypresses.
Pressing a number key from 0 to 9 labels the incoming data as that class.
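The page does not spell out how labelled samples are written to disk. Judging from the loading code further down (np.uint16, reshaped to 8 columns), each sample appears to be stored as eight unsigned 16-bit integers, one per EMG channel, appended to vals<class>.dat. The sketch below works under that assumption. append_sample and load_class are hypothetical helpers, not pyomyo functions.

```python
import os
import struct
import tempfile

import numpy as np

# Hypothetical helpers (not part of pyomyo). Assumed layout: each sample is
# 8 native-endian unsigned 16-bit integers appended to vals<cls>.dat.
def append_sample(directory, cls, sample):
    """Append one 8-channel reading to the file for class `cls`."""
    with open(os.path.join(directory, f"vals{cls}.dat"), "ab") as f:
        f.write(struct.pack("8H", *sample))

def load_class(directory, cls):
    """Read every reading stored for class `cls` as an (n_samples, 8) array."""
    path = os.path.join(directory, f"vals{cls}.dat")
    return np.fromfile(path, dtype=np.uint16).reshape((-1, 8))

# Usage: label two readings as class 1 (thumb) in a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    append_sample(tmp, 1, [512] * 8)
    append_sample(tmp, 1, [100, 200, 300, 400, 500, 600, 700, 800])
    thumb = load_class(tmp, 1)

print(thumb.shape)  # (2, 8)
```

Because each keypress simply appends to the file for the current class, the per-class files grow independently. This is why the examples for each class can be counted and loaded separately.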

Labelling data for finger classification

In the gif above, I label the movement of each finger with a number from 1 to 5, using 0 for rest.
I make a gesture with my right hand and label it with my left hand.
Note that consistent placement and rotation of the Myo on the skin is key to getting good data; see the placement section of the wiki for more detail.

In the gif, you can see the bars change; they represent the current prediction for the incoming data.
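The page does not describe how simple_classifier produces its prediction or how the bar heights are computed. The sketch below is purely illustrative. It assumes a k-nearest-neighbour classifier whose recent predictions are tallied into per-class fractions. knn_predict, PredictionBars, k and window are hypothetical names and parameters, not pyomyo's API.

```python
from collections import Counter, deque

import numpy as np

def knn_predict(X_train, y_train, sample, k=5):
    """Majority vote among the k labelled examples closest to `sample`."""
    # Cast to float first: subtracting uint16 arrays would wrap around instead of going negative
    diffs = X_train.astype(float) - np.asarray(sample, dtype=float)
    nearest = np.argsort(np.linalg.norm(diffs, axis=1))[:k]
    return Counter(y_train[nearest].tolist()).most_common(1)[0][0]

class PredictionBars:
    """Hypothetical helper: fraction of the last `window` predictions given to each class."""

    def __init__(self, n_classes, window=25):
        self.n_classes = n_classes
        self.history = deque(maxlen=window)

    def update(self, prediction):
        self.history.append(prediction)
        counts = np.bincount(list(self.history), minlength=self.n_classes)
        return counts / len(self.history)

# Toy training set: class 0 readings near 0, class 1 readings near 100
X_train = np.array([[0] * 8] * 3 + [[100] * 8] * 3, dtype=np.uint16)
y_train = np.array([0, 0, 0, 1, 1, 1])

bars = PredictionBars(n_classes=2, window=4)
for reading in ([5] * 8, [95] * 8, [98] * 8):
    heights = bars.update(knn_predict(X_train, y_train, reading, k=3))
print(heights)  # bar heights after one class-0 and two class-1 predictions
```

Tallying a short window of predictions makes the bars less jittery than showing each single-sample prediction. The cost is a little lag when the gesture changes.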
Next to the numerical name of each class is the number of examples collected for that class. Ideally, we would gather the same number of examples for each class. A lazy classifier can minimise its error by always predicting the most common class, so keeping the class counts similar discourages it from using this strategy.
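The lazy strategy's accuracy is simply the share of the largest class, so an imbalanced dataset makes doing nothing look good. The page only recommends collecting similar numbers of examples. Random downsampling is one common remedy when counts have already drifted apart, and it is swapped in here for illustration. majority_baseline and balance_by_downsampling are hypothetical helpers.

```python
import numpy as np

def majority_baseline(labels):
    """Accuracy of a 'lazy' classifier that always predicts the most common class."""
    counts = np.bincount(labels)
    return counts.max() / counts.sum()

def balance_by_downsampling(X, y, seed=0):
    """Randomly keep the same number of examples (the smallest class size) from every class."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    n = counts.min()
    keep = np.concatenate([
        rng.choice(np.flatnonzero(y == c), size=n, replace=False) for c in classes
    ])
    keep.sort()  # preserve the original recording order
    return X[keep], y[keep]

# Toy example: 90 "rest" samples (class 0) and only 10 "thumb" samples (class 1)
y = np.array([0] * 90 + [1] * 10)
X = np.zeros((len(y), 8), dtype=np.uint16)

print(majority_baseline(y))  # 0.9: 90% accuracy without learning anything
X_bal, y_bal = balance_by_downsampling(X, y)
print(majority_baseline(y_bal))  # 0.5 once both classes have 10 examples
```

Downsampling throws data away. Recording more examples for the under-represented classes, as the page suggests, is usually the better fix.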

The gathered data is saved in the data directory as one binary file per class (valsN.dat, where N is the class number). Below, we load these files with numpy and merge them into one pandas DataFrame.

```python
import numpy as np
import pandas as pd

# Set the path of your labelled data
path = "./data/"

# Load each class's recordings: raw uint16 values, 8 EMG channels per sample
rest   = np.fromfile(path + "vals0.dat", dtype=np.uint16).reshape((-1, 8))
thumb  = np.fromfile(path + "vals1.dat", dtype=np.uint16).reshape((-1, 8))
index  = np.fromfile(path + "vals2.dat", dtype=np.uint16).reshape((-1, 8))
middle = np.fromfile(path + "vals3.dat", dtype=np.uint16).reshape((-1, 8))
ring   = np.fromfile(path + "vals4.dat", dtype=np.uint16).reshape((-1, 8))
pinky  = np.fromfile(path + "vals5.dat", dtype=np.uint16).reshape((-1, 8))
# Make sure to change this if you use a different labelling convention
data = [rest, thumb, index, middle, ring, pinky]
label_cols = ['rest', 'thumb', 'index', 'middle', 'ring', 'pinky']

# Stack simple_classifier's per-class arrays into one DataFrame (one row per sample)
df = pd.DataFrame(np.vstack(data))
```
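The stacked DataFrame holds the readings but not which class each row came from. One way to attach labels is sketched below. It assumes the rows stay in the same order as `data`, so class c owns the next len(data[c]) rows. The synthetic arrays stand in for the real .dat files, and `label` is a hypothetical column name.

```python
import numpy as np
import pandas as pd

# Stand-ins for the arrays loaded from vals0.dat ... vals5.dat (synthetic, for illustration)
rng = np.random.default_rng(0)
label_cols = ['rest', 'thumb', 'index', 'middle', 'ring', 'pinky']
data = [rng.integers(0, 1024, size=(n, 8), dtype=np.uint16) for n in (4, 3, 3, 2, 5, 3)]

df = pd.DataFrame(np.vstack(data))
# Class index for every row: vstack keeps the order of `data`, so repeat c len(data[c]) times
df['label'] = np.repeat(np.arange(len(data)), [len(d) for d in data])

# Optional one-hot columns named after label_cols
one_hot = pd.get_dummies(df['label']).astype(int)
one_hot.columns = [label_cols[c] for c in one_hot.columns]
df = pd.concat([df, one_hot], axis=1)
print(df['label'].value_counts().sort_index())
```

The integer `label` column suits most scikit-learn style classifiers. The one-hot columns suit models trained with a softmax output.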