Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions examples/Sentiment/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
<!DOCTYPE html>
<html lang="en">

<head>
  <meta charset="utf-8">
  <title>ml5 - Sentiment</title>
  <script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/1.6.0/p5.js"></script>
  <script src="../../dist/ml5.js"></script>
</head>

<body>
  <h1>Sentiment Analysis Demo</h1>
  <p>
    This example uses a model trained on movie reviews. This model scores the sentiment of text with
    a value between 0 ("negative") and 1 ("positive"). The movie reviews were truncated to a
    maximum of 200 words and only the 20,000 most common words in the reviews are used.
    <!-- <br> is a void element; the original "<br></br>" closing tag is invalid HTML -->
    <br>
    Press 'Enter' on your keyboard or 'Submit' to see score!
  </p>

  <script src="sketch.js"></script>
</body>

</html>
44 changes: 44 additions & 0 deletions examples/Sentiment/sketch.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
let sentiment;       // ml5 sentiment model instance (created in setup)
let statusEl;        // to display model loading status
let submitBtn;       // button that triggers a prediction
let inputBox;        // text input holding the sentence to score
let sentimentResult; // paragraph element that displays the resulting score

// p5.js entry point: loads the sentiment model and builds the page controls.
// Element creation order matters — it determines the layout of the page.
function setup() {
  noCanvas(); // DOM-only sketch; no drawing canvas is needed
  // initialize sentiment analysis model; modelReady fires once loading is done
  sentiment = ml5.sentiment("movieReviews", modelReady);

  // setup the html environment
  statusEl = createP("Loading Model...");
  inputBox = createInput("Today is the happiest day and is full of rainbows!");
  inputBox.attribute("size", "75");
  submitBtn = createButton("submit");
  sentimentResult = createP("Sentiment score:");

  // predicting the sentiment on mousePressed()
  submitBtn.mousePressed(getSentiment);
}

// Reads the current input text, scores it with the model, and renders the
// resulting score into the result paragraph.
function getSentiment() {
  const phrase = inputBox.value();

  // sentiment.predict returns an object of the shape { score }
  const { score } = sentiment.predict(phrase);

  sentimentResult.html("Sentiment score: " + score);
}

// Invoked by ml5 once the sentiment model has finished loading;
// replaces the "Loading Model..." status text on the page.
function modelReady() {
  statusEl.html("Model loaded");
}

// predicting the sentiment when 'Enter' key is pressed
function keyPressed() {
if (keyCode == ENTER) {
getSentiment();
}
}
135 changes: 135 additions & 0 deletions src/Sentiment/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import * as tf from "@tensorflow/tfjs";
import callCallback from "../utils/callcallback";
import modelLoader from "../utils/modelLoader";

// Special word-index values shared with the pre-trained model's metadata.
const OOV_CHAR = 2; // index representing an out-of-vocabulary word
const PAD_CHAR = 0; // index used to pad sequences that are too short

/**
 * Truncates and/or pads each sequence of word indices to exactly `maxLen`.
 *
 * The original implementation used `splice`, which mutated the caller's
 * arrays in place; this version is non-mutating and always returns fresh
 * arrays.
 *
 * @param {number[][]} sequences - Arrays of word indices.
 * @param {number} maxLen - Target length for every sequence.
 * @param {'pre'|'post'} [padding='pre'] - Side on which padding is added.
 * @param {'pre'|'post'} [truncating='pre'] - Side from which excess values are dropped.
 * @param {number} [value=PAD_CHAR] - Value used for padding.
 * @returns {number[][]} New arrays, each exactly `maxLen` long.
 */
function padSequences(
  sequences,
  maxLen,
  padding = "pre",
  truncating = "pre",
  value = PAD_CHAR
) {
  return sequences.map((seq) => {
    // Truncate: keep the last maxLen entries ('pre') or the first ('post').
    let out =
      seq.length > maxLen
        ? truncating === "pre"
          ? seq.slice(seq.length - maxLen)
          : seq.slice(0, maxLen)
        : seq.slice();

    // Pad with `value` on the requested side until the sequence is maxLen long.
    if (out.length < maxLen) {
      const pad = new Array(maxLen - out.length).fill(value);
      out = padding === "pre" ? pad.concat(out) : out.concat(pad);
    }
    return out;
  });
}

class Sentiment {
  /**
   * Create a Sentiment model. Currently the supported model name is
   * 'moviereviews'. ml5 may support different models in the future.
   * @param {String} modelName - 'moviereviews', or a path/URL to a model.json.
   * @param {function} callback - Optional. Called once the model has loaded.
   *   If no callback is provided, await `this.ready` instead.
   */
  constructor(modelName, callback) {
    /**
     * Result of wrapping the load in callCallback — resolves once the model
     * has loaded.
     * NOTE(review): the original comment called this a boolean, but it is
     * whatever callCallback returns (presumably a Promise) — confirm upstream.
     * @public
     */
    this.ready = callCallback(this.loadModel(modelName), callback);
  }

  /**
   * Loads the layers model and its metadata (vocabulary, padding parameters)
   * in parallel, then copies the metadata fields onto this instance.
   * @param {String} modelName - 'moviereviews' or a custom model URL.
   * @returns {Promise<Sentiment>} this instance, once everything has loaded.
   */
  async loadModel(modelName) {
    const modelUrl =
      modelName.toLowerCase() === "moviereviews"
        ? "https://storage.googleapis.com/tfjs-models/tfjs/sentiment_cnn_v1/"
        : modelName;

    const loader = modelLoader(modelUrl, "model");

    await tf.setBackend("webgl");

    // load in parallel
    const [model, sentimentMetadata] = await Promise.all([
      loader.loadLayersModel(),
      loader.loadMetadataJson(),
    ]);

    /**
     * The model being used.
     * @type {tf.LayersModel}
     * @public
     */
    this.model = model;

    // Offset added to every vocabulary index (low indices are reserved for
    // special characters such as PAD_CHAR and OOV_CHAR).
    this.indexFrom = sentimentMetadata.index_from;
    // Sequence length the model expects; input is padded/truncated to this.
    this.maxLen = sentimentMetadata.max_len;

    // word -> index lookup table and the size of the usable vocabulary.
    this.wordIndex = sentimentMetadata.word_index;
    this.vocabularySize = sentimentMetadata.vocabulary_size;

    return this;
  }

  /**
   * Scores the sentiment of given text with a value between 0 ("negative")
   * and 1 ("positive").
   * @param {String} text - string of text to predict.
   * @returns {{score: Number}}
   */
  predict(text) {
    // Convert to lower case and remove basic punctuation.
    const inputText = text
      .trim()
      .toLowerCase()
      .replace(/[.,?!]/g, "")
      .split(" ");

    // Convert the words to a sequence of word indices. Words that are absent
    // from the vocabulary map to OOV_CHAR. (Previously a missing word produced
    // `undefined + indexFrom` = NaN, which failed the `> vocabularySize` guard
    // and was fed straight into the tensor.)
    const sequence = inputText.map((word) => {
      const rawIndex = this.wordIndex[word];
      if (rawIndex === undefined) {
        return OOV_CHAR;
      }
      const wordIndex = rawIndex + this.indexFrom;
      return wordIndex > this.vocabularySize ? OOV_CHAR : wordIndex;
    });

    // Perform truncation and padding, then run the model.
    const paddedSequence = padSequences([sequence], this.maxLen);
    const input = tf.tensor2d(paddedSequence, [1, this.maxLen]);
    const predictOut = this.model.predict(input);
    const score = predictOut.dataSync()[0];
    // Release the memory backing the intermediate tensors.
    predictOut.dispose();
    input.dispose();

    return {
      score,
    };
  }
}

/**
 * Factory: creates a new Sentiment instance.
 * @param {String} modelName - 'moviereviews', or a path/URL to a model.json.
 * @param {function} [callback] - called once the model has loaded.
 * @returns {Sentiment}
 */
const sentiment = (modelName, callback) => new Sentiment(modelName, callback);

export default sentiment;
2 changes: 2 additions & 0 deletions src/index.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import neuralNetwork from "./NeuralNetwork";
import handpose from "./Handpose";
import sentiment from "./Sentiment";
import facemesh from "./Facemesh";
import poseDetection from "./PoseDetection";
import * as tf from "@tensorflow/tfjs";
Expand All @@ -15,6 +16,7 @@ export default Object.assign(
tfvis,
neuralNetwork,
handpose,
sentiment,
facemesh,
poseDetection,
setBackend,
Expand Down
139 changes: 139 additions & 0 deletions src/utils/modelLoader.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import * as tf from "@tensorflow/tfjs";
import axios from "axios";

/**
 * Reports whether the URL string is absolute — i.e. begins with an optional
 * scheme followed by "//" (e.g. "http://…", "https://…", or the
 * protocol-relative "//…").
 * @param {string} str
 * @returns {boolean}
 */
export function isAbsoluteURL(str) {
  return /^(?:[a-z]+:)?\/\//i.test(str);
}

/**
 * Accepts a URL that may be a complete URL, or a relative location.
 * Returns an absolute URL based on the current window location.
 * Outside a browser (no `window`), the input is returned untouched.
 * @param {string} absoluteOrRelativeUrl
 * @returns {string}
 */
export function getModelPath(absoluteOrRelativeUrl) {
  const inBrowser = typeof window !== "undefined";
  if (inBrowser && !isAbsoluteURL(absoluteOrRelativeUrl)) {
    // NOTE(review): this prepends the page's pathname verbatim; presumably
    // callers pass paths relative to the current page — confirm intent.
    return window.location.pathname + absoluteOrRelativeUrl;
  }
  return absoluteOrRelativeUrl;
}

// True when `name` is one of the JSON file basenames this loader understands.
function isKnownName(name) {
  switch (name) {
    case "model":
    case "manifest":
    case "metadata":
      return true;
    default:
      return false;
  }
}

/**
 * Fetches JSON data from `url`, wrapping any failure in an Error that
 * identifies which file could not be loaded. Shared by the manifest and
 * metadata loaders, which previously duplicated this logic.
 * @param {string} url
 * @param {string} filename - e.g. 'manifest.json'; used in the error message.
 * @returns {Promise<any>}
 */
async function loadJsonFile(url, filename) {
  try {
    const res = await axios.get(url);
    return res.data;
  } catch (error) {
    throw new Error(
      `Error loading ${filename} file from URL ${url}: ${String(error)}`
    );
  }
}

/**
 * @property {string} directory
 * @property {string} modelUrl
 * @property {string} manifestUrl
 * @property {string} metadataUrl
 */
class ModelLoader {
  /**
   * Can provide the url to a model.json/metadata.json/manifest.json file,
   * or to a folder containing the files.
   *
   * @param {string} path
   * @param {'model'|'manifest'|'metadata'} expected - which file type `path` is expected to name.
   * @param {boolean} prepend - when true, resolve `path` against the page location first.
   */
  constructor(path, expected = "model", prepend = true) {
    const url = prepend ? getModelPath(path) : path;
    // Maps a known file kind ('model'/'manifest'/'metadata') to the explicit
    // URL the caller provided for it, if any.
    const known = {};
    // If a specific URL is provided, make sure that we don't overwrite it with generic '/model.json'
    // But warn the user and try to correct if it seems like they passed the wrong file type.
    if (url.endsWith(".json")) {
      const pos = url.lastIndexOf("/") + 1;
      this.directory = url.slice(0, pos);
      // Basename without the trailing '.json' (5 characters).
      const fileName = url.slice(pos, -5);
      if (fileName !== expected && isKnownName(fileName)) {
        console.warn(
          `Expected a ${expected}.json file URL, but received a ${fileName}.json file instead.`
        );
      } else {
        known[expected] = url;
      }
    } else {
      this.directory = url.endsWith("/") ? url : `${url}/`;
    }
    // Fall back to conventional file names inside the directory for anything
    // the caller did not name explicitly.
    this.modelUrl = known.model || this.getPath("model.json");
    this.metadataUrl = known.metadata || this.getPath("metadata.json");
    this.manifestUrl = known.manifest || this.getPath("manifest.json");
  }

  /**
   * Appends the filename to the base directory; absolute URLs pass through.
   * @param {string} filename
   * @return {string}
   */
  getPath(filename) {
    return isAbsoluteURL(filename) ? filename : this.directory + filename;
  }

  /**
   * Fetch the JSON data from the manifest file, and throw an error if not found.
   * @return {Promise<any>}
   */
  async loadManifestJson() {
    return loadJsonFile(this.manifestUrl, "manifest.json");
  }

  /**
   * Fetch the JSON data from the metadata file, and throw an error if not found.
   * @return {Promise<any>}
   */
  async loadMetadataJson() {
    return loadJsonFile(this.metadataUrl, "metadata.json");
  }

  /**
   * Pass the model URL to the TensorFlow loadLayersModel function.
   * If no path is provided, loads file `/model.json` relative to the directory.
   * But can also be called with the model url from a manifest file.
   * @param {string} [relativePath]
   * @return {Promise<tf.LayersModel>}
   */
  async loadLayersModel(relativePath) {
    const url = relativePath ? this.getPath(relativePath) : this.modelUrl;
    try {
      return await tf.loadLayersModel(url);
    } catch (error) {
      throw new Error(`Error loading model from URL ${url}: ${String(error)}`);
    }
  }
}

/**
 * Convenience factory for ModelLoader.
 * @param {string} path - URL of a model/metadata/manifest JSON file, or of the folder containing them.
 * @param {'model'|'manifest'|'metadata'} [expected] - which file type `path` points at (defaults to 'model').
 * @param {boolean} [prepend] - when true (the default), resolve relative paths against the page location.
 * @return {ModelLoader}
 */
export default function modelLoader(path, expected, prepend) {
  return new ModelLoader(path, expected, prepend);
}