# coding: utf-8
import matplotlib.pyplot as plt
import numpy as np

def plot(I):
    """Show the 2-D array ``I`` as a matrix image.

    ``block=False`` lets the script keep running instead of waiting for the
    figure window to be closed.
    """
    # fignum=None is matshow's default: draw into a new figure each call.
    plt.matshow(I, fignum=None)
    plt.show(block=False)
    
def to_pattern(letter):
    r"""Convert an ASCII-art glyph into a flat bipolar (+1/-1) vector.

    Every ``'X'`` becomes ``+1`` and every other character becomes ``-1``.
    Line breaks only separate rows and are dropped before encoding.
    ``splitlines`` is used rather than removing just ``'\n'``, so the ``'\r'``
    of Windows-style ``\r\n`` endings is dropped too. Otherwise each ``'\r'``
    would be encoded as an extra ``-1`` cell and break a later ``reshape``.

    Args:
        letter: glyph text whose rows use ``'X'`` for "on" cells and any
            other character for "off" cells.

    Returns:
        1-D integer numpy array of ``+1``/``-1`` values in row-major order
        (empty, but still integer, when ``letter`` has no cells).
    """
    cells = "".join(letter.splitlines())
    return np.array([+1 if c == 'X' else -1 for c in cells], dtype=int)
		  
A = """
.XXX.
X...X
XXXXX
X...X
X...X
"""
Z = """
XXXXX
...X.
..X..
.X...
XXXXX
"""
plot(to_pattern(A).reshape(5,5))

def train(patterns):
    """Compute Hopfield weights from stored patterns with the Hebbian rule.

    Each pattern ``p`` contributes its outer product ``p p^T``. The sum is
    divided by the number of patterns, and the diagonal is zeroed so that no
    unit feeds back onto itself.

    Args:
        patterns: array-like of shape ``(r, c)`` holding ``r`` patterns of
            length ``c`` (for example, rows built with ``to_pattern``).
            A single 1-D pattern is treated as one row (``r == 1``).

    Returns:
        Float ``(c, c)`` symmetric weight matrix with a zero diagonal.
        When no patterns are given (``r == 0``), the all-zero matrix is
        returned instead of dividing by zero, which would yield NaNs.
    """
    P = np.atleast_2d(np.asarray(patterns, dtype=float))
    r, c = P.shape
    # P.T @ P equals the sum over rows p of np.outer(p, p).
    W = P.T @ P
    np.fill_diagonal(W, 0)
    if r == 0:
        return W
    return W / r
	 

