From 5e9929f706d8ebf13aeffe58ad2169f1a2946894 Mon Sep 17 00:00:00 2001 From: Silei Date: Fri, 16 Apr 2021 13:56:18 -0700 Subject: [PATCH 01/32] Auto-annotate: avoid punctuation & pronouns in annotations --- lib/pos-parser/index.ts | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/lib/pos-parser/index.ts b/lib/pos-parser/index.ts index caeb908b1..7cc57b5ac 100644 --- a/lib/pos-parser/index.ts +++ b/lib/pos-parser/index.ts @@ -112,6 +112,16 @@ export default class PosParser { for (const template of this.queryTemplates[pos]) { const match = template.match(utterance, domainCanonicals, value); if (match && !match.includes('$domain') && match.split(' ').length - 1 < MAX_LENGTH) { + // FIXME: capture these in templates + // skip matches containing punctuations that always introduce a break in the utterance + if (/[,.!?:]/.test(match)) + continue; + // skip reverse property that contains a pronoun + if (pos === 'reverse_property') { + const tokens = match.split(' '); + if (tokens.includes('it') || tokens.includes('that') || tokens.includes('this')) + continue; + } if (pos === 'verb' && match.startsWith('$value ')) { return [ { pos, canonical: match }, From f38389ccee31422d4a2432abae4905e232571cd6 Mon Sep 17 00:00:00 2001 From: Silei Date: Sun, 25 Apr 2021 23:47:53 -0700 Subject: [PATCH 02/32] Cache wikidata end point queries --- tool/autoqa/wikidata/utils.js | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tool/autoqa/wikidata/utils.js b/tool/autoqa/wikidata/utils.js index b31be6f8e..22e985258 100644 --- a/tool/autoqa/wikidata/utils.js +++ b/tool/autoqa/wikidata/utils.js @@ -22,6 +22,8 @@ import * as Tp from 'thingpedia'; const URL = 'https://query.wikidata.org/sparql'; +const _cache = new Map(); + const WikidataUnitToTTUnit = { // time 'millisecond': 'ms', @@ -144,11 +146,15 @@ function unitConverter(wikidataUnit) { * @returns {Promise<*>} */ async function wikidataQuery(query) { + if (_cache.has(query)) + return _cache.get(query); try { const result = await Tp.Helpers.Http.get(`${URL}?query=${encodeURIComponent(query)}`, { accept: 'application/json' }); - return JSON.parse(result).results.bindings; + const parsed = JSON.parse(result).results.bindings; + _cache.set(query, parsed); + return parsed; } catch(e) { const error = new Error('The connection timed out waiting for a response'); error.code = 500; From 174e11be80b6248546ba03f4b2247f189dd5ec76 Mon Sep 17 00:00:00 2001 From: Silei Xu Date: Mon, 28 Jun 2021 23:39:36 -0700 Subject: [PATCH 03/32] Fix a bug in the pos tagger --- lib/i18n/english.ts | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lib/i18n/english.ts b/lib/i18n/english.ts index e123bb51e..01457fd7a 100644 --- a/lib/i18n/english.ts +++ b/lib/i18n/english.ts @@ -22,6 +22,7 @@ import { Inflectors } from 'en-inflectors'; import { Tag } from 'en-pos'; +import * as lexicon from 'en-lexicon'; import { coin } from '../utils/random'; import { Phrase } from '../utils/template-string'; @@ -135,6 +136,13 @@ function indefiniteArticle(word : string) { export default class EnglishLanguagePack extends DefaultLanguagePack { protected _tokenizer : EnglishTokenizer|undefined; + constructor(locale : string) { + super(locale); + + // the pos tagger will crash without this lexicon extension + lexicon.extend({ constructor: 'NN' }); + } + getTokenizer() : EnglishTokenizer { if (this._tokenizer) return this._tokenizer; From 55452d8cca771cb983efdfda87b31cee5e5968ed Mon Sep 17 00:00:00 2001 From: Silei Xu Date: Tue, 29 Jun 2021 23:41:13 -0700 Subject: [PATCH 04/32] Bug fix in pos-parser: avoid inherit from Object.prototype Similar to the bug discovered in the pos tagger, when the key is "constructor", an obejct created from literal syntax will return a function instead of undefined when 'constructor' key is missing --- lib/pos-parser/nfa.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/pos-parser/nfa.ts b/lib/pos-parser/nfa.ts index 7756d3500..d7be2b836 100644 --- a/lib/pos-parser/nfa.ts +++ b/lib/pos-parser/nfa.ts @@ -42,7 +42,7 @@ class State { constructor(isEnd = false) { this.id = stateCounter++; this.isEnd = isEnd; - this.transitions = {}; + this.transitions = Object.create(null); } addTransition(token : string, to : State, capturing = false) { From ca72d2b25dca2c748b1c668f02f4865600089edf Mon Sep 17 00:00:00 2001 From: Silei Xu Date: Mon, 21 Jun 2021 17:28:00 -0700 Subject: [PATCH 05/32] Add a hack to include common_name as human type for CSQA --- lib/templates/utils.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/templates/utils.ts b/lib/templates/utils.ts index 730f03e09..d478770e8 100644 --- a/lib/templates/utils.ts +++ b/lib/templates/utils.ts @@ -308,6 +308,8 @@ function isHumanEntity(type : Type|string) : boolean { return false; if (['tt:contact', 'tt:username', 'org.wikidata:human'].includes(type)) return true; + if (type === 'org.wikidata:common_name') // hack for CSQA dataset + return true; if (type.startsWith('org.schema') && type.endsWith(':Person')) return true; return false; From 4d1351a081ff8c73bb7518f5233b1cf7dedba9bb Mon Sep 17 00:00:00 2001 From: Silei Xu Date: Fri, 8 Oct 2021 11:05:56 -0700 Subject: [PATCH 06/32] Refactor Wikidata script leveraging bootleg and CSQA in TypeScript Data source: - Use CSQA preprocessed files as a source of data, only query wikidata service end point when information is missing in CSQA dumps - Use Bootleg data as a reference when determine the type of an entity Type system: In total, 3 difference option is provided (1) entity-plain: one entity type per property based on property name (2) entity-hierarchical: one entity type for each value, and the property type is the supertype of all types of its values; the property type has a prefix `p_` (3) string: everthing string except id When entity-heirarchical is selected, ThingTalk: - Add the option to include Entity value (QID) in ThingTalk Other tricks: - Remove properties that share the same label as the domain: In wikidata, each country has a property country, poiting to itself, and it's confusing - Remove trailing QIDs in bootleg type canonical: When there are multiple types share the same canonical, bootleg will append QID at the end of the canonical for all types except one. Sometimes, it also append QID to some type that does not share canonical with other types, such as "nation", and "designation", not sure why. In our case, the type information is to assist the parsing, the actual QID is not important, so we drop all the appended QIDs - if two types share the same canonical, they are considered as the same type in natural language. --- lib/utils/stream-utils.ts | 9 +- lib/utils/thingtalk/syntax.ts | 4 +- .../lib/canonical-example-paraphraser.ts | 8 + tool/autoqa/wikidata/csqa-converter.ts | 823 ++++++++++++++++++ tool/autoqa/wikidata/csqa-type-mapper.ts | 317 +++++++ tool/autoqa/wikidata/demo.ts | 10 +- tool/autoqa/wikidata/label-retriever.js | 67 -- tool/autoqa/wikidata/make-string-datasets.js | 309 ------- tool/autoqa/wikidata/manual-annotations.js | 128 --- tool/autoqa/wikidata/manual-annotations.ts | 32 + tool/autoqa/wikidata/preprocess-bootleg.ts | 74 ++ .../wikidata/preprocess-knowledge-base.ts | 529 +++++++++++ tool/autoqa/wikidata/process-schema.ts | 293 +++++++ tool/autoqa/wikidata/{utils.js => utils.ts} | 261 +++++- tool/genie.ts | 7 +- 15 files changed, 2314 insertions(+), 557 deletions(-) create mode 100644 tool/autoqa/wikidata/csqa-converter.ts create mode 100644 tool/autoqa/wikidata/csqa-type-mapper.ts delete mode 100644 tool/autoqa/wikidata/label-retriever.js delete mode 100644 tool/autoqa/wikidata/make-string-datasets.js delete mode 100644 tool/autoqa/wikidata/manual-annotations.js create mode 100644 tool/autoqa/wikidata/manual-annotations.ts create mode 100644 tool/autoqa/wikidata/preprocess-bootleg.ts create mode 100644 tool/autoqa/wikidata/preprocess-knowledge-base.ts create mode 100644 tool/autoqa/wikidata/process-schema.ts rename tool/autoqa/wikidata/{utils.js => utils.ts} (50%) diff --git a/lib/utils/stream-utils.ts b/lib/utils/stream-utils.ts index 781793f20..7bf70c07c 100644 --- a/lib/utils/stream-utils.ts +++ b/lib/utils/stream-utils.ts @@ -258,9 +258,16 @@ export { CountStream, }; -export function waitFinish(stream : Stream.Writable) : Promise { +export function waitFinish(stream : NodeJS.WritableStream) : Promise { return new Promise((resolve, reject) => { stream.once('finish', resolve); stream.on('error', reject); }); } + +export function waitEnd(stream : NodeJS.ReadableStream) : Promise { + return new Promise((resolve, reject) => { + stream.once('end', resolve); + stream.on('error', reject); + }); +} \ No newline at end of file diff --git a/lib/utils/thingtalk/syntax.ts b/lib/utils/thingtalk/syntax.ts index f0e8e694d..2e7291399 100644 --- a/lib/utils/thingtalk/syntax.ts +++ b/lib/utils/thingtalk/syntax.ts @@ -140,6 +140,7 @@ interface SerializeOptions { timezone : string|undefined; ignoreSentence ?: boolean; compatibility ?: string; + includeEntityValue ?: boolean; } /** @@ -158,6 +159,7 @@ export function serializePrediction(program : Ast.Input, ignoreSentence: options.ignoreSentence || false, }); return Syntax.serialize(program, Syntax.SyntaxType.Tokenized, entityRetriever, { - compatibility: options.compatibility + compatibility: options.compatibility, + includeEntityValue: options.includeEntityValue }); } diff --git a/tool/autoqa/lib/canonical-example-paraphraser.ts b/tool/autoqa/lib/canonical-example-paraphraser.ts index 20c2c6ef5..9272ed58e 100644 --- a/tool/autoqa/lib/canonical-example-paraphraser.ts +++ b/tool/autoqa/lib/canonical-example-paraphraser.ts @@ -45,6 +45,14 @@ export default class Paraphraser { if (process.env.CI || process.env.TRAVIS) return; + // output paraphrase input + if (this.options.debug) { + const output = util.promisify(fs.writeFile); + await output('./paraphraser-in.json', JSON.stringify(examples.map((e) => { + return { utterance: e.utterance, arg: e.argument, value: e.value ?? null }; + }), null, 2)); + } + // call genienlp to run paraphrase const args = [ `run-paraphrase`, diff --git a/tool/autoqa/wikidata/csqa-converter.ts b/tool/autoqa/wikidata/csqa-converter.ts new file mode 100644 index 000000000..e953bb657 --- /dev/null +++ b/tool/autoqa/wikidata/csqa-converter.ts @@ -0,0 +1,823 @@ +// -*- mode: typescript; indent-tabs-mode: nil; js-basic-offset: 4 -*- +// +// This file is part of Genie +// +// Copyright 2019-2021 The Board of Trustees of the Leland Stanford Junior University +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Author: Naoki Yamamura +// Silei Xu + +import * as argparse from 'argparse'; +import * as fs from 'fs'; +import assert from 'assert'; +import * as util from 'util'; +import * as path from 'path'; +import * as ThingTalk from 'thingtalk'; +import { Ast, Type } from 'thingtalk'; +import * as I18N from '../../../lib/i18n'; +import { serializePrediction } from '../../../lib/utils/thingtalk'; +import { getElementType, getItemLabel, argnameFromLabel, readJson, Domains } from './utils'; +import { makeDummyEntities } from "../../../lib/utils/entity-utils"; + +async function loadClassDef(thingpedia : string, options : ThingTalk.Syntax.ParseOptions) { + const schema = await util.promisify(fs.readFile)(thingpedia, { encoding: 'utf8' }); + const library = ThingTalk.Syntax.parse(schema, ThingTalk.Syntax.SyntaxType.Normal, options); + assert(library instanceof Ast.Library && library.classes.length === 1); + return library.classes[0]; +} + +export interface CSQADialogueTurn { + speaker : 'USER'|'SYSTEM', + utterance : string, + ques_type_id : 1|2|3|4|5|6|7|8, + sec_ques_type ?: 1|2, + sec_ques_sub_type ?: 1|2|3|4, + is_inc ?: 0|1, + is_incomplete ?: 0|1, + bool_ques_type ?: 1|2|3|4|5|6, + inc_ques_type ?: 1|2|3, + set_op_choice ?: 1|2|3, + set_op ?: 1|2, + count_ques_sub_type ?: 1|2|3|4|5|6|7|8|9, + type_list ?: string[], + entities_in_utterance ?: string[], + active_set ?: string[] +} + +interface CSQADialogueTurnPair { + file : string, + system : CSQADialogueTurn, + user : CSQADialogueTurn, +} + +interface ParameterRecord { + value : string, + preprocessed : string +} + +interface CSQAConverterOptions { + locale : string; + timezone ?: string; + domains : Domains, + includeEntityValue : boolean, + filter : string, + softMatchId : boolean, + inputDir : string, + output : string, + thingpedia : string, + wikidataProperties : string, + items : string, + values : string, + types : string, + filteredExamples : string +} + +class CsqaConverter { + private _locale : string; + private _timezone ?: string; + private _domains : Domains; + private _includeEntityValue : boolean; + private _softMatchId : boolean; + private _filters : Record; + private _paths : Record; + private _classDef : Ast.ClassDef|null; + private _items : Map>; + private _values : Map; + private _types : Map; + private _wikidataProperties : Map; + private _examples : CSQADialogueTurnPair[]; + private _tokenizer : I18N.BaseTokenizer; + private _unsupportedCounter : Record; + + constructor(options : CSQAConverterOptions) { + this._locale = options.locale; + this._timezone = options.timezone; + this._domains = options.domains; + this._includeEntityValue = options.includeEntityValue; + this._softMatchId = options.softMatchId; + this._filters = {}; + for (const filter of options.filter || []) { + assert(filter.indexOf('=') > 0 && filter.indexOf('=') === filter.lastIndexOf('=')); + const [key, values] = filter.split('='); + this._filters[key] = values.split(',').map((v) => parseInt(v)); + } + + this._paths = { + inputDir: options.inputDir, + output: options.output, + thingpedia: options.thingpedia, + wikidataProperties: options.wikidataProperties, + items: options.items, + values: options.values, + types: options.types, + filteredExamples: options.filteredExamples + }; + + this._classDef = null; + this._items = new Map(); + this._values = new Map(); + this._types = new Map(); + + this._wikidataProperties = new Map(); + + this._examples = []; + this._tokenizer = I18N.get('en-US').getTokenizer(); + + this._unsupportedCounter = { + indirect: 0, + setOp: 0, + typeConstraint: 0, + wrongAnnotation: 0 + }; + } + + private async _getArgValue(qid : string) : Promise { + let value; + if (this._values.has(qid)) { + value = this._values.get(qid); + } else { + value = await getItemLabel(qid); + if (value) + this._values.set(qid, value); + } + if (value) + return { value: qid, preprocessed: this._tokenizer.tokenize(value).tokens.join(' ') }; + throw new Error(`Label not found for ${qid}`); + } + + private _invocationTable(domain : string) : Ast.Expression { + const selector = new Ast.DeviceSelector(null, 'org.wikidata', null, null); + return new Ast.InvocationExpression(null, new Ast.Invocation(null, selector, domain, [], null), null); + } + + private _generateFilter(domain : string, param : string, value : ParameterRecord) : Ast.BooleanExpression { + let ttValue, op; + if (param === 'id') { + if (this._softMatchId) { + ttValue = new Ast.Value.String(value.preprocessed); + op = '=~'; + } else { + ttValue = new Ast.Value.Entity(value.value, `org.wikidata:${domain}`, value.preprocessed); + op = '=='; + } + } else { + const propertyType = this._classDef!.getFunction('query', domain)!.getArgType(param)!; + const entityType = this._types.get(value.value); + const valueType = entityType ? new Type.Entity(`org.wikidata:${entityType}`) : getElementType(propertyType); + if (valueType instanceof Type.Entity) { + ttValue = new Ast.Value.Entity(value.value, valueType.type, value.preprocessed); + op = propertyType.isArray ? 'contains' : '=='; + } else { // Default to string + ttValue = new Ast.Value.String(value.preprocessed); + op = propertyType.isArray ? 'contains~' : '=~'; + } + } + return new Ast.BooleanExpression.Atom(null, param, op, ttValue); + } + + private _getDomainBySubject(x : string) : string|null { + if (x.startsWith('c')) + return this._domains.getDomainByCSQAType(x.slice(1)); + for (const [domain, items] of this._items) { + if (x in items) + return domain; + } + return null; + } + + // returns [domain, projection, filter] + private async _processOneActiveSet(activeSet : string[][]) : Promise<[string, string[]|Ast.BooleanExpression|null, Ast.BooleanExpression]> { + const triple = activeSet[0]; + const domain = this._getDomainBySubject(triple[0]); + assert(domain); + const subject = triple[0].startsWith('c') ? null : await this._getArgValue(triple[0]); + const relation = await argnameFromLabel(this._wikidataProperties.get(triple[1])!); + const object = triple[2].startsWith('c') ? null : await this._getArgValue(triple[2]); + + // when object is absent, return a projection on relation with filtering on id = subject + if (subject && !object) + return [domain, [relation], this._generateFilter(domain, 'id', subject)]; + // when subject is absent, return a filter on the relation with the object value + if (!subject && object) + return [domain, null, this._generateFilter(domain, relation, object)]; + // when both subject and object exists, then it's a verification question + // return a boolean expression as projection, and a filter on id = subject + if (subject && object) + return [domain, this._generateFilter(domain, relation, object), this._generateFilter(domain, 'id', subject)]; + + throw new Error('Both subject and object absent in the active set entry: ' + activeSet); + } + + // returns [domain, projection, filter] + private async _processOneActiveSetWithSetOp(activeSet : string[][], setOp : number) : Promise<[string, string[]|null, Ast.BooleanExpression|null]> { + assert(activeSet.length === 1 && activeSet[0].length > 3 && activeSet[0].length % 3 === 0); + const triples = []; + for (let i = 0; i < activeSet[0].length; i += 3) + triples.push(activeSet[0].slice(i, i + 3)); + // when the subjects of some triples are different, it requires set operation + // in ThingTalk to represent, which is not supported yet + const subjects = new Set(triples.map(((triple) => triple[0]))); + if (subjects.size > 1) + return ['unknown', null, null]; + + // process tripes in active set + const domains = []; + const projections = []; + const filters = []; + for (let i = 0; i < triples.length; i ++) { + const [domain, projection, filter] = await this._processOneActiveSet(triples.slice(i, i+1)); + domains.push(domain); + projections.push(projection as string[]); // it won't be boolean question for set ops + filters.push(filter); + } + + // FIXME: we current don't handle multiple domains + assert((new Set(domains)).size === 1); + const domain = domains[0]; + + // when projection is not null, it means we should have the same id filter on + // both triple, and different projection + if (projections[0] && projections[0].length > 0) { + const uniqueProjections = [...new Set(projections.flat())]; + return [domain, uniqueProjections, filters[0]]; + } + // when projection is null, then we merge two filters according to setOp + switch (setOp) { + case 1: return [domain, null, new Ast.BooleanExpression.Or(null, filters)]; // OR + case 2: return [domain, null, new Ast.BooleanExpression.And(null, filters)]; // AND + case 3: { // DIFF + assert(filters.length === 2); + const negateFilter = new Ast.BooleanExpression.Not(null, filters[1]); + return [domain, null, new Ast.BooleanExpression.And(null, [filters[0], negateFilter])]; + } + default: + throw new Error(`Unknown set_op_choice: ${setOp}`); + } + } + + // ques_type_id=1 + private async _simpleQuestion(activeSet : string[][]) : Promise { + assert(activeSet.length === 1); + const [domain, projection, filter] = await this._processOneActiveSet(activeSet); + const filterTable = new Ast.FilterExpression(null, this._invocationTable(domain), filter, null); + if (projection && Array.isArray(projection) && projection.length > 0) + return new Ast.ProjectionExpression(null, filterTable, projection, [], [], null); + return filterTable; + } + + // ques_type_id=2 + private async _secondaryQuestion(activeSet : string[][], secQuesType : number, secQuesSubType : number) : Promise { + if (secQuesSubType === 2 || secQuesSubType === 3) { + this._unsupportedCounter.indirect += 1; + return null; + } + if (secQuesSubType === 1) { + if (activeSet.length !== 1) { + this._unsupportedCounter.wrongAnnotation += 1; + return null; + } + return this._simpleQuestion(activeSet); + } + if (secQuesSubType === 4) { + // this it basically is asking multiple questions in one sentence. + // it is sometimes ambiguous with set-based questions + if (activeSet.length <= 1) + throw new Error('Only one active set found for secondary plural question'); + const domains = []; + const projections = []; + const filters = []; + for (let i = 0; i < activeSet.length; i ++) { + const [domain, projection, filter] = await this._processOneActiveSet(activeSet.slice(i, i+1)); + domains.push(domain); + projections.push(projection as string[]); + filters.push(filter); + } + + // FIXME: we current don't handle multiple domains + assert((new Set(domains)).size === 1); + const domain = domains[0]; + + const filter = new Ast.BooleanExpression.Or(null, filters); + const filterTable = new Ast.FilterExpression(null, this._invocationTable(domain), filter, null); + // when subjects of triples are entity, we are asking the same projection for multiple entities + if (secQuesType === 1) { + const uniqueProjection = [...new Set(projections.flat())]; + assert(uniqueProjection.length === 1); + return new Ast.ProjectionExpression(null, filterTable, uniqueProjection, [], [], null); + } + // when subjects of triples are type (domain), we are asking multiple questions, each of which + // satisfies a different filter + if (secQuesType === 2) + return filterTable; + throw new Error('Invalid sec_ques_type for secondary question'); + } + throw new Error('Invalid sec_sub_ques_type for secondary question'); + } + + // ques_type_id=4 + private async _setBasedQuestion(activeSet : string[][], setOpChoice : number) : Promise { + assert(activeSet.length === 1); + const [domain, projection, filter] = await this._processOneActiveSetWithSetOp(activeSet, setOpChoice); + if (!projection && !filter) { + this._unsupportedCounter.setOp += 1; + return null; + } + + const filterTable = new Ast.FilterExpression(null, this._invocationTable(domain!), filter!, null); + if (projection && projection.length > 0) + return new Ast.ProjectionExpression(null, filterTable, projection, [], [], null); + return filterTable; + } + + + // ques_type_id=5 + private async _booleanQuestion(activeSet : string[][], boolQuesType : number) : Promise { + if (boolQuesType === 1) { + assert(activeSet.length === 1); + const [domain, projection, filter] = await this._processOneActiveSet(activeSet); + const filterTable = new Ast.FilterExpression(null, this._invocationTable(domain), filter, null); + return new Ast.BooleanQuestionExpression(null, filterTable, projection as Ast.BooleanExpression, null); + } + if (boolQuesType === 4) { + assert(activeSet.length === 2); + const [domain1, projection1, filter] = await this._processOneActiveSet(activeSet); + const [domain2, projection2, ] = await this._processOneActiveSet(activeSet.slice(1)); + // FIXME: we current don't handle multiple domains + assert(domain1 === domain2); + const filterTable = new Ast.FilterExpression(null, this._invocationTable(domain1), filter, null); + const projection = new Ast.BooleanExpression.And(null, [projection1, projection2]); + return new Ast.BooleanQuestionExpression(null, filterTable, projection, null); + } + // indirect questions + this._unsupportedCounter.indirect += 1; + return null; + } + + // ques_type_id=7 + private async _quantitativeQuestionsSingleEntity(activeSet : string[][], entities : string[], countQuesSubType : number, utterance : string) : Promise { + switch (countQuesSubType) { + case 1: { // Quantitative (count) + assert(activeSet.length === 1); + const [domain, projection, filter] = await this._processOneActiveSet(activeSet); + return this._quantitativeQuestionCount(domain, projection as string[], filter); + } + case 2: // Quantitative (min/max) + return this._quantitativeQuestionMinMax(activeSet, utterance); + case 3: // Quantitative (atleast/atmost/~~/==) + return this._quantitativeQuestionCompareCount(activeSet, utterance); + case 4: // Comparative (more/less/~~) + return this._comparativeQuestion(activeSet, entities, utterance); + case 5: { // Quantitative (count over atleast/atmost/~~/==) + const filterTable = await this._quantitativeQuestionCompareCount(activeSet, utterance); + return new Ast.AggregationExpression(null, filterTable, '*', 'count', null); + } + case 6: { // Comparative (count over more/less/~~) + const filterTable = await this._comparativeQuestion(activeSet, entities, utterance); + return new Ast.AggregationExpression(null, filterTable, '*', 'count', null); + } + case 7: + case 8: + case 9: + // indirect questions + this._unsupportedCounter.indirect += 1; + return null; + default: + throw new Error(`Unknown count_ques_sub_type: ${countQuesSubType}`); + } + } + + // ques_type_id=8 + private async _quantitativeQuestionsMultiEntity(activeSet : string[][], entities : string[], countQuesSubType : number, setOpChoice : number, utterance : string) : Promise { + // Somehow set op is reverse of question type 4, there is no diff set op in this category + const setOp = setOpChoice === 2 ? 1:2; + switch (countQuesSubType) { + case 1: { // Quantitative with logical operators + assert(activeSet.length === 2); + activeSet = [activeSet[0].concat(activeSet[1])]; + const [domain, projection, filter] = await this._processOneActiveSetWithSetOp(activeSet, setOp); + if (!projection && !filter) { + this._unsupportedCounter.setOp += 1; + return null; + } + return this._quantitativeQuestionCount(domain, projection as string[], filter!); + } + case 2: // Quantitative (count) + case 3: // Quantitative (min/max) + case 4: // Quantitative (atleast/atmost/~~/==) + case 5: // Comparative (more/less/~~) + case 6: // Quantitative (count over atleast/atmost/~~/==) + case 7: // Comparative (count over more/less/~~) + this._unsupportedCounter.typeConstraint += 1; + return null; + case 8: + case 9: + case 10: + // indirect questions + this._unsupportedCounter.indirect += 1; + return null; + default: + throw new Error(`Unknown count_ques_sub_type: ${countQuesSubType}`); + } + } + + private _quantitativeOperator(utterance : string) : string { + // there is literally only one single way to talk about most aggregation + // operators in CSQA, so it's easy to decide + if (utterance.includes(' min ')) + return 'asc'; + if (utterance.includes(' max ')) + return 'desc'; + if (utterance.includes(' atleast' )) + return '>='; + if (utterance.includes(' atmost ')) + return '<='; + if (utterance.includes(' exactly ')) + return '=='; + if (utterance.includes(' approximately ') || utterance.includes(' around ')) + return '~~'; + throw new Error('Failed to identify quantitative operator based on the utterance'); + } + + private _comparativeOperator(utterance : string) : string { + if (utterance.includes(' more ') || utterance.includes(' greater number ')) + return '>='; + if (utterance.includes(' less ') || utterance.includes(' lesser number ')) + return '<='; + if (utterance.includes(' same number ')) + return '~~'; + + throw new Error('Failed to identify comparative operator based on the utterance'); + } + + private _numberInUtterance(utterance : string) : number { + // we expect exactly one number in the utterance + const matches = utterance.match(/\d+/); + if (!matches || matches.length === 0) + throw new Error('Failed to locate numbers from the utterance'); + if (matches.length > 1) + throw new Error('Multiple numbers found in the utterance'); + return parseInt(matches[0]); + } + + // Quantitative (count) + private async _quantitativeQuestionCount(domain : string, projection : string[], filter : Ast.BooleanExpression) : Promise { + const filterTable = new Ast.FilterExpression(null, this._invocationTable(domain), filter, null); + // when projection exists, it is counting parameter on a table with id filter + if (projection) { + const computation = new Ast.Value.Computation( + 'count', + projection.map((param) => new Ast.Value.VarRef(param)) + ); + return new Ast.ProjectionExpression(null, filterTable, [], [computation], [null], null); + } + // when projection is absent, it is counting a table with a regular filter + return new Ast.AggregationExpression(null, filterTable, '*', 'count', null); + } + + // Quantitative (min/max) + private async _quantitativeQuestionMinMax(activeSet : string[][], utterance : string) : Promise { + assert(activeSet.length === 1); + const triple = activeSet[0]; + if (!triple[0].startsWith('c') || !triple[2].startsWith('c')) { + this._unsupportedCounter.wrongAnnotation += 1; + return null; + } + const propertyLabel = this._wikidataProperties.get(triple[1]); + assert(propertyLabel); + const param = await argnameFromLabel(propertyLabel); + const computation = new Ast.Value.Computation( + 'count', + [new Ast.Value.VarRef(param)] + ); + const domain = this._getDomainBySubject(triple[0]); + assert(domain); + const countTable = new Ast.ProjectionExpression(null, this._invocationTable(domain), [], [computation], [null], null); + const direction = this._quantitativeOperator(utterance); + const sortTable = new Ast.SortExpression(null, countTable, new Ast.Value.VarRef('count'), direction as "asc"|"desc", null); + return new Ast.IndexExpression(null, sortTable, [new Ast.Value.Number(1)], null); + } + + // Quantitative (atleast/atmost/~~/==) + private async _quantitativeQuestionCompareCount(activeSet : string[][], utterance : string) : Promise { + assert(activeSet.length === 1); + const triple = activeSet[0]; + assert(triple[0].startsWith('c') && triple[2].startsWith('c')); + const propertyLabel = this._wikidataProperties.get(triple[1]); + assert(propertyLabel); + const param = await argnameFromLabel(propertyLabel); + const computation = new Ast.Value.Computation( + 'count', + [new Ast.Value.VarRef(param)] + ); + const filter = new Ast.BooleanExpression.Compute( + null, + computation, + this._quantitativeOperator(utterance), + new Ast.Value.Number(this._numberInUtterance(utterance)), + null + ); + const domain = this._getDomainBySubject(triple[0]); + assert(domain); + return new Ast.FilterExpression(null, this._invocationTable(domain), filter, null); + } + + // comparative (more/less/~~) + private async _comparativeQuestion(activeSet : string[][], entities : string[], utterance : string) : Promise { + assert(activeSet.length === 1 && entities.length === 1); + const triple = activeSet[0]; + assert(triple[0].startsWith('c') && triple[2].startsWith('c')); + const domain = this._getDomainBySubject(triple[0]); + const propertyLabel = this._wikidataProperties.get(triple[1]); + assert(domain && propertyLabel); + const param = await argnameFromLabel(propertyLabel); + const comparisonTarget = await this._getArgValue(entities[0]); + const filter = this._generateFilter(domain, 'id', comparisonTarget); + const subquery = new Ast.ProjectionExpression( + null, + new Ast.FilterExpression(null, this._invocationTable(domain), filter, null), + [], + [new Ast.Value.Computation('count', [new Ast.Value.VarRef(param)])], + [null], + null + ); + return new Ast.FilterExpression( + null, + this._invocationTable(domain), + new Ast.ComparisonSubqueryBooleanExpression( + null, + new Ast.Value.Computation('count', [new Ast.Value.VarRef(param)]), + this._comparativeOperator(utterance), + subquery, + null + ), + null + ); + } + + async csqaToThingTalk(dialog : CSQADialogueTurnPair) : Promise { + const user = dialog.user; + const system = dialog.system; + + if (user.is_incomplete || user.is_inc) { + this._unsupportedCounter.indirect += 1; + return null; + } + + const activeSet = []; + assert(system.active_set); + for (const active of system.active_set) + activeSet.push(active.replace(/[^0-9PQc,|]/g, '').split(',')); + + switch (user.ques_type_id) { + case 1: // Simple Question (subject-based) + return this._simpleQuestion(activeSet); + case 2: // Secondary question + return this._secondaryQuestion(activeSet, user.sec_ques_type!, user.sec_ques_sub_type!); + case 3: // Clarification (for secondary) question + this._unsupportedCounter.indirect += 1; + return null; + case 4: // Set-based question + return this._setBasedQuestion(activeSet, user.set_op_choice!); + case 5: // Boolean (Factual Verification) question + return this._booleanQuestion(activeSet, user.bool_ques_type!); + case 6: // Incomplete question (for secondary) + this._unsupportedCounter.indirect += 1; + return null; + case 7: // Comparative and Quantitative questions (involving single entity) + return this._quantitativeQuestionsSingleEntity(activeSet, user.entities_in_utterance!, user.count_ques_sub_type!, user.utterance); + case 8: // Comparative and Quantitative questions (involving multiple(2) entities) + return this._quantitativeQuestionsMultiEntity(activeSet, user.entities_in_utterance!, user.count_ques_sub_type!, user.set_op!, user.utterance); + default: + throw new Error(`Unknown ques_type_id: ${user.ques_type_id}`); + } + } + + private async _filterTurnsByDomain(dialog : CSQADialogueTurn[], file : string) { + let userTurn; + for (const turn of dialog) { + const speaker = turn.speaker; + if (speaker === 'USER') { + let skip = false; + for (const [key, values] of Object.entries(this._filters)) { + if (!values.includes(turn[key as keyof CSQADialogueTurn] as number)) + skip = true; + } + userTurn = skip ? null : turn; + } else { + if (!userTurn) + continue; + assert(turn.active_set); + + // only consider examples that contain _only_ the given domain + let inDomain = true; + for (const active of turn.active_set) { + const triples = active.replace(/[^0-9PQc,|]/g, '').split(','); + for (let i = 0; i < triples.length; i += 3) { + const subject = triples[i]; + const domain = this._getDomainBySubject(subject); + if (!domain && !this._items.has(subject)) + inDomain = false; + } + } + if (inDomain) { + this._examples.push({ + file: file, + user: userTurn, + system: turn, + }); + } + } + } + } + + async _filterExamples() { + for (const dir of fs.readdirSync(this._paths.inputDir)) { + for (const file of fs.readdirSync(path.join(this._paths.inputDir, dir))) { + const dialog = JSON.parse(fs.readFileSync(path.join(this._paths.inputDir, dir, file), { encoding: 'utf-8'})); + this._filterTurnsByDomain(dialog, file); + } + } + console.log(`${this._examples.length} QA pairs found`); + await util.promisify(fs.writeFile)(this._paths.filteredExamples, JSON.stringify(this._examples, undefined, 2)); + } + + async _loadFilteredExamples() { + this._examples = JSON.parse(await util.promisify(fs.readFile)(this._paths.filteredExamples, { encoding: 'utf-8'})); + } + + async _convert() { + const annotated = []; + const skipped = []; + const error = []; + for (const example of this._examples) { + let expression; + try { + expression = await this.csqaToThingTalk(example); + } catch(e) { + console.log('Error during conversion:'); + console.log('question:', example.user.utterance); + console.log('triples:', example.system.active_set); + console.error(e.message); + expression = null; + } + + if (!expression) { + skipped.push(example); + continue; + } + + try { + const program = new Ast.Program(null, [], [], [new Ast.ExpressionStatement(null, expression)]); + const user = example.user; + const preprocessed = this._tokenizer.tokenize(user.utterance).tokens.join(' '); + const entities = makeDummyEntities(preprocessed); + const thingtalk = serializePrediction(program, preprocessed, entities, { locale: this._locale, timezone: this._timezone, includeEntityValue : this._includeEntityValue }).join(' '); + annotated.push({ + id : annotated.length + 1, + raw: user.utterance, + preprocessed, + thingtalk + }); + } catch(e) { + console.log('Error during serializing:'); + console.log('question:', example.user.utterance); + console.log('triples:', example.system.active_set); + console.error(e.message); + error.push(example); + } + } + console.log(`${annotated.length} annotated, ${skipped.length} skipped, ${error.length} thrown error.`); + console.log(`Among skipped questions:`); + console.log(`(1) indirect questions: ${this._unsupportedCounter.indirect}`); + console.log(`(2) set operations: ${this._unsupportedCounter.setOp}`); + console.log(`(3) type constraint: ${this._unsupportedCounter.typeConstraint}`); + console.log(`(4) wrong annotation: ${this._unsupportedCounter.wrongAnnotation}`); + return annotated; + } + + async run() { + this._classDef = await loadClassDef(this._paths.thingpedia, { locale: this._locale, timezone: this._timezone }); + this._items = await readJson(this._paths.items); + this._values = await readJson(this._paths.values); + this._types = await readJson(this._paths.types); + this._wikidataProperties = await readJson(this._paths.wikidataProperties); + + // load in-domain examples + if (fs.existsSync(this._paths.filteredExamples)) + await this._loadFilteredExamples(); + else + await this._filterExamples(); + + // convert dataset annotation into thingtalk + const dataset = await this._convert(); + + // output thingtalk dataset + await util.promisify(fs.writeFile)(this._paths.output, dataset.map((example) => { + return `${example.id}\t${example.preprocessed}\t${example.thingtalk}`; + }).join('\n'), { encoding: 'utf8' }); + } +} + +module.exports = { + initArgparse(subparsers : argparse.SubParser) { + const parser = subparsers.add_parser('wikidata-convert-csqa', { + add_help: true, + description: "Generate parameter-datasets.tsv from processed wikidata dump. " + }); + parser.add_argument('-l', '--locale', { + default: 'en-US', + help: `BGP 47 locale tag of the natural language being processed (defaults to en-US).` + }); + parser.add_argument('--timezone', { + required: false, + default: undefined, + help: `Timezone to use to interpret dates and times (defaults to the current timezone).` + }); + parser.add_argument('-o', '--output', { + required: true, + }); + parser.add_argument('-i', '--input', { + required: true, + }); + parser.add_argument('--domains', { + required: true, + help: 'the path to the file containing type mapping for each domain' + }); + parser.add_argument('--thingpedia', { + required: true, + help: 'Path to ThingTalk file containing class definitions.' + }); + parser.add_argument('--wikidata-property-list', { + required: true, + help: "full list of properties in the wikidata dump, named filtered_property_wikidata4.json" + + "in CSQA, in the form of a dictionary with PID as keys and canonical as values." + }); + parser.add_argument('--items', { + required: true, + help: "A json file containing the labels for items of the domain" + }); + parser.add_argument('--values', { + required: true, + help: "A json file containing the labels for value entities for the domain" + }); + parser.add_argument('--types', { + required: true, + help: "A json file containing the entity types for value entities in the domain" + }); + parser.add_argument('--filtered-examples', { + required: true, + help: "A json file containing in-domain examples of the given CSQA dataset" + }); + parser.add_argument('--entity-id', { + action: 'store_true', + help: "Include entity id in thingtalk", + default: false + }); + parser.add_argument('--soft-match-id', { + action: 'store_true', + help: "Do string soft match on id property", + default: false + }); + parser.add_argument('--filter', { + required: false, + default: [], + nargs: '+', + help: 'filters to be applied to CSQA dataset, in the format of [key]=[value(int)]' + }); + }, + + async execute(args : any) { + const domains = new Domains({ path: args.domains }); + await domains.init(); + const csqaConverter = new CsqaConverter({ + locale: args.locale, + timezone : args.timezone, + domains, + inputDir: args.input, + output: args.output, + thingpedia: args.thingpedia, + wikidataProperties: args.wikidata_property_list, + items: args.items, + values: args.values, + types: args.types, + filteredExamples: args.filtered_examples, + includeEntityValue: args.entity_id, + softMatchId: args.soft_match_id, + filter: args.filter + }); + csqaConverter.run(); + }, + CsqaConverter +}; diff --git a/tool/autoqa/wikidata/csqa-type-mapper.ts b/tool/autoqa/wikidata/csqa-type-mapper.ts new file mode 100644 index 000000000..a2ff946ba --- /dev/null +++ b/tool/autoqa/wikidata/csqa-type-mapper.ts @@ -0,0 +1,317 @@ +// -*- mode: typescript; indent-tabs-mode: nil; js-basic-offset: 4 -*- +// +// This file is part of Genie +// +// Copyright 2021 The Board of Trustees of the Leland Stanford Junior University +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Author: Silei Xu + + +import * as argparse from 'argparse'; +import * as fs from 'fs'; +import * as path from 'path'; +import assert from 'assert'; +import csvstringify from 'csv-stringify'; +import JSONStream from 'JSONStream'; +import * as StreamUtils from '../../../lib/utils/stream-utils'; +import { argnameFromLabel } from './utils'; +import { CSQADialogueTurn } from './csqa-converter'; + +const pfs = fs.promises; + +interface CSQATypeMapperOptions { + domains ?: string[], + input_dir : string, + output : string, + wikidata : string, + wikidata_labels : string, + minimum_appearance : number, + minimum_percentage : number, +} + +// map experiment name to CSQA type +const DOMAIN_MAP : Record = { + 'human': 'common_name', + 'city': 'administrative_territorial_entity', + 'country': 'designation_for_an_administrative_territorial_entity', + 'art': 'work_of_art', + 'song': 'release', + 'music_band': 'musical_ensemble', + 'game': 'application', + 'organization': 'organization', + 'disease': 'health_problem', + 'tv': 'television_program', + 'drug': 'drug' +}; + +class CSQATypeMapper { + private _inputDir : string; + private _output : string; + private _wikidata : string; + private _wikidataLabels : string; + private _minAppearance : number; + private _minPercentage : number; + private _domains ?: string[]; + private _labels : Map; + private _wikidataTypes : Map; + private _wikidataSuperTypes : Map; + private _typeMap : Map; + + + constructor(options : CSQATypeMapperOptions) { + this._inputDir = options.input_dir; + this._output = options.output; + this._wikidata = options.wikidata; + this._wikidataLabels = options.wikidata_labels; + this._minAppearance = options.minimum_appearance; + this._minPercentage = options.minimum_percentage; + this._domains = options.domains ? options.domains.map((domain : string) => DOMAIN_MAP[domain] || domain) : undefined; + + this._labels = new Map(); + this._wikidataTypes = new Map(); + this._wikidataSuperTypes = new Map(); + this._typeMap = new Map(); + } + + private async _loadKB(kbfile : string) { + const pipeline = fs.createReadStream(kbfile).pipe(JSONStream.parse('$*')); + pipeline.on('data', async (item) => { + const entity = item.key; + const predicates = item.value; + if ('P31' in predicates) { + const entityTypes = predicates['P31']; + this._wikidataTypes.set(entity, entityTypes); + for (const type of entityTypes) + this._labels.set(type, undefined); + } + if ('P279' in predicates) { + const superTypes = predicates['P279']; + this._wikidataSuperTypes.set(entity, superTypes); + for (const type of superTypes) + this._labels.set(type, undefined); + } + }); + + pipeline.on('error', (error) => console.error(error)); + await StreamUtils.waitEnd(pipeline); + } + + private async _loadLabels() { + const pipeline = fs.createReadStream(this._wikidataLabels).pipe(JSONStream.parse('$*')); + pipeline.on('data', async (entity) => { + const qid = String(entity.key); + const label = String(entity.value); + if (this._labels.has(qid)) + this._labels.set(qid, label); + }); + + pipeline.on('error', (error) => console.error(error)); + await StreamUtils.waitEnd(pipeline); + } + + async load() { + console.log('loading wikidata files ...'); + for (const kbfile of this._wikidata) + await this._loadKB(kbfile); + + console.log('loading wikidata labels ...'); + await this._loadLabels(); + } + + private _processDialog(dialog : CSQADialogueTurn[]) { + let userTurn, systemTurn; + for (const turn of dialog) { + if (turn.speaker === 'USER') { + userTurn = turn; + continue; + } + + assert(userTurn && turn.speaker === 'SYSTEM'); + systemTurn = turn; + + // extract examples from type 2.2.1, where an singular object-based question is asked. + // ie., given a relation and an object in the triple, asking for the subject + if (userTurn.ques_type_id === 2 && userTurn.sec_ques_type === 2 && userTurn.sec_ques_sub_type === 1) { + assert(userTurn.type_list && userTurn.type_list.length === 1); + const csqaType = userTurn.type_list[0]; + if (!this._typeMap.has(csqaType)) + this._typeMap.set(csqaType, { total: 0 }); + const answer = systemTurn.entities_in_utterance!; + for (const entity of answer) { + if (!this._wikidataTypes.has(entity)) + continue; + for (const type of this._wikidataTypes.get(entity) ?? []) { + const map = this._typeMap.get(csqaType); + map.total += 1; + if (!(type in map)) + map[type] = 1; + else + map[type] +=1; + } + } + } + } + } + + private async _processFile(file : string) { + const dialog = JSON.parse(await pfs.readFile(file, { encoding: 'utf8' })); + this._processDialog(dialog); + } + + private async _processDir(dir : string) { + const files = await pfs.readdir(dir); + for (const file of files) { + const fpath = path.join(dir, file); + const stats = await pfs.lstat(fpath); + if (stats.isDirectory() || stats.isSymbolicLink()) + await this._processDir(fpath); + else if (file.startsWith('QA') && file.endsWith('.json')) + await this._processFile(fpath); + } + } + + async process() { + console.log('start processing dialogs'); + await this._processDir(this._inputDir); + } + + + /** + * Output the type mapping in a tsv format, where each column shows + * 1. CSQA domain name (label of wikidata type) + * 2. CSQA domain wikidata type (QID) + * 3. the actual wikidata types for entities in the CSQA domain (filtered, used as the type map) + * 4. the actual wikidata types for entities in the CSQA domain (unfiltered, used as a reference only) + * + * each wikidata type in 3 and 4 is in the following format, and separated by space + * :