Skip to content
This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Commit b69fb02

Browse files
authored
src: allow operators to access graph in initialization (#227)
1 parent a183fa6 commit b69fb02

File tree

9 files changed

+37
-16
lines changed

9 files changed

+37
-16
lines changed

lib/backend.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,16 @@ export interface SessionHandler {
3232

3333
/**
3434
* Resolves the operator from the name and opset version; backend specific
35-
* @param node
36-
* @param opsets
35+
* @param node the node to resolve
36+
* @param opsets a list of opsets that exported from the model
37+
* @param graph the completely initialized graph
3738
*/
39+
resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>, graph: Graph): Operator;
3840

39-
resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>): Operator;
4041
/**
4142
* This method let's the sessionHandler know that the graph initialization is complete
4243
* @param graph the completely initialized graph
4344
*/
44-
4545
onGraphInitialized?(graph: Graph): void;
4646

4747
/**

lib/backends/cpu/session-handler.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ export class CpuSessionHandler implements SessionHandler {
1919

2020
dispose(): void {}
2121

22-
resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>): Operator {
22+
resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>, graph: Graph): Operator {
2323
const op = resolveOperator(node, opsets, CPU_OP_RESOLVE_RULES);
24-
op.initialize(node.attributes);
24+
op.initialize(node.attributes, node, graph);
2525
return op;
2626
}
2727
}

lib/backends/wasm/session-handler.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ export class WasmSessionHandler implements SessionHandler {
2323

2424
dispose(): void {}
2525

26-
resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>): Operator {
26+
resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>, graph: Graph): Operator {
2727
const op = resolveOperator(node, opsets, this.opResolveRules);
28-
op.initialize(node.attributes);
28+
op.initialize(node.attributes, node, graph);
2929
return op;
3030
}
3131
}

lib/backends/webgl/session-handler.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ export class WebGLSessionHandler implements SessionHandler {
5656
this.textureDataCache.forEach(td => this.textureManager.releaseTexture(td, true));
5757
this.textureDataCache = new Map();
5858
}
59-
resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>): Operator {
59+
resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>, graph: Graph): Operator {
6060
const op = resolveOperator(node, opsets, WEBGL_OP_RESOLVE_RULES);
61-
op.initialize(node.attributes);
61+
op.initialize(node.attributes, node, graph);
6262
return op;
6363
}
6464
}

lib/operators.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33

44
import {Attribute} from './attribute';
55
import {InferenceHandler} from './backend';
6+
import {Graph} from './graph';
67
import {Tensor} from './tensor';
78

89
export interface Operator {
9-
initialize(attributes: Attribute): void;
10+
initialize(attributes: Attribute, node: Graph.Node, graph: Graph): void;
1011
checkInputs(inputs: Tensor[]): boolean;
1112
run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
1213
}

lib/session.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ export class Session {
232232
this._ops = new Array(nodes.length);
233233

234234
for (let i = 0; i < nodes.length; i++) {
235-
this._ops[i] = this.sessionHandler.resolve(nodes[i], this._model.opsets);
235+
this._ops[i] = this.sessionHandler.resolve(nodes[i], this._model.opsets, graph);
236236
}
237237
}
238238

test/test-runner.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import {Logger, Profiler} from '../lib/instrument';
1616
import {Operator} from '../lib/operators';
1717
import {Tensor} from '../lib/tensor';
1818

19-
import {base64toBuffer} from './test-shared';
19+
import {base64toBuffer, createMockGraph} from './test-shared';
2020
import {Test} from './test-types';
2121

2222
// the threshold that used to compare 2 float numbers. See above for TensorResultValidator.floatEqual().
@@ -301,7 +301,8 @@ function initializeOperator(
301301
opsetImports: ReadonlyArray<Test.OperatorTestOpsetImport>): Operator {
302302
const attributes = new Attribute(undefined);
303303
attributeValues.forEach(value => attributes.set(value.name, value.type, value.data));
304-
return sessionHandler.resolve({name: '', opType, inputs: [], outputs: [], attributes}, opsetImports);
304+
const graph = createMockGraph(opType, attributes);
305+
return sessionHandler.resolve(graph.getNodes()[0], opsetImports, graph);
305306
}
306307

307308
/**

test/test-shared.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
import * as fs from 'fs';
55
import {promisify} from 'util';
66

7+
import {Attribute} from '../lib/attribute';
8+
import {Graph} from '../lib/graph';
9+
710
export function base64toBuffer(data: string): Uint8Array {
811
return Buffer.from(data, 'base64');
912
}
@@ -28,3 +31,18 @@ async function readFile(file: string) {
2831
return Buffer.from(buffer);
2932
}
3033
}
34+
35+
/**
36+
* create a single-node graph for unit test purpose
37+
*/
38+
export function createMockGraph(opType: string, attributes: Attribute): Graph {
39+
const node: Graph.Node = {name: '', opType, inputs: [], outputs: [], attributes};
40+
return {
41+
getInputIndices: () => [],
42+
getInputNames: () => [],
43+
getOutputIndices: () => [],
44+
getOutputNames: () => [],
45+
getNodes: () => [node],
46+
getValues: () => []
47+
};
48+
}

test/unittests/backends/webgl/test_conv_new.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import {CpuConv} from '../../../../lib/backends/cpu/ops/conv';
88
import {Profiler} from '../../../../lib/instrument';
99
import {Tensor} from '../../../../lib/tensor';
1010
import {TensorResultValidator} from '../../../test-runner';
11+
import {createMockGraph} from '../../../test-shared';
1112

1213
const validator = new TensorResultValidator('webgl');
1314
let webglBackend: Backend|undefined;
@@ -28,8 +29,8 @@ function webglConv(
2829
attributes.set('pads', 'ints', pads);
2930
}
3031
attributes.set('strides', 'ints', strides);
31-
const op = webglSessionhandler!.resolve(
32-
{opType: 'Conv', attributes, inputs: [], outputs: [], name: `Conv`}, [{domain: '', version: 7}]);
32+
const graph = createMockGraph('Conv', attributes);
33+
const op = webglSessionhandler!.resolve(graph.getNodes()[0], [{domain: '', version: 7}], graph);
3334
if (!op.checkInputs([inputTensor, kernelTensor])) {
3435
throw new Error('Invalid inputs');
3536
}

0 commit comments

Comments
 (0)