001/** 002 * Portions Copyright 2001 Sun Microsystems, Inc. 003 * Portions Copyright 1999-2001 Language Technologies Institute, 004 * Carnegie Mellon University. 005 * All Rights Reserved. Use is subject to license terms. 006 * 007 * See the file "license.terms" for information on usage and 008 * redistribution of this file, and for a DISCLAIMER OF ALL 009 * WARRANTIES. 010 */ 011package com.sun.speech.freetts.cart; 012 013import java.io.BufferedReader; 014import java.io.DataInputStream; 015import java.io.DataOutputStream; 016import java.io.IOException; 017import java.io.InputStreamReader; 018import java.io.PrintWriter; 019import java.net.URL; 020import java.nio.ByteBuffer; 021import java.util.StringTokenizer; 022import java.util.logging.Level; 023import java.util.logging.Logger; 024import java.util.regex.Pattern; 025 026import com.sun.speech.freetts.Item; 027import com.sun.speech.freetts.PathExtractor; 028import com.sun.speech.freetts.PathExtractorImpl; 029import com.sun.speech.freetts.util.Utilities; 030 031/** 032 * Implementation of a Classification and Regression Tree (CART) that is 033 * used more like a binary decision tree, with each node containing a 034 * decision or a final value. The decision nodes in the CART trees 035 * operate on an Item and have the following format: 036 * 037 * <pre> 038 * NODE feat operand value qfalse 039 * </pre> 040 * 041 * <p>Where <code>feat</code> is an string that represents a feature 042 * to pass to the <code>findFeature</code> method of an item. 043 * 044 * <p>The <code>value</code> represents the value to be compared against 045 * the feature obtained from the item via the <code>feat</code> string. 046 * The <code>operand</code> is the operation to do the comparison. The 047 * available operands are as follows: 048 * 049 * <ul> 050 * <li>< - the feature is less than value 051 * <li>= - the feature is equal to the value 052 * <li>> - the feature is greater than the value 053 * <li>MATCHES - the feature matches the regular expression stored in value 054 * <li>IN - [[[TODO: still guessing because none of the CART's in 055 * Flite seem to use IN]]] the value is in the list defined by the 056 * feature. 057 * </ul> 058 * 059 * <p>[[[TODO: provide support for the IN operator.]]] 060 * 061 * <p>For < and >, this CART coerces the value and feature to 062 * float's. For =, this CART coerces the value and feature to string and 063 * checks for string equality. For MATCHES, this CART uses the value as a 064 * regular expression and compares the obtained feature to that. 065 * 066 * <p>A CART is represented by an array in this implementation. The 067 * <code>qfalse</code> value represents the index of the array to go to if 068 * the comparison does not match. In this implementation, qtrue index 069 * is always implied, and represents the next element in the 070 * array. The root node of the CART is the first element in the array. 071 * 072 * <p>The interpretations always start at the root node of the CART 073 * and continue until a final node is found. The final nodes have the 074 * following form: 075 * 076 * <pre> 077 * LEAF value 078 * </pre> 079 * 080 * <p>Where <code>value</code> represents the value of the node. 081 * Reaching a final node indicates the interpretation is over and the 082 * value of the node is the interpretation result. 083 */ 084public class CARTImpl implements CART { 085 /** Logger instance. */ 086 private static final Logger LOGGER = 087 Logger.getLogger(CARTImpl.class.getName()); 088 /** 089 * Entry in file represents the total number of nodes in the 090 * file. This should be at the top of the file. The format 091 * should be "TOTAL n" where n is an integer value. 092 */ 093 final static String TOTAL = "TOTAL"; 094 095 /** 096 * Entry in file represents a node. The format should be 097 * "NODE feat op val f" where 'feat' represents a feature, op 098 * represents an operand, val is the value, and f is the index 099 * of the node to go to is there isn't a match. 100 */ 101 final static String NODE = "NODE"; 102 103 /** 104 * Entry in file represents a final node. The format should be 105 * "LEAF val" where val represents the value. 106 */ 107 final static String LEAF = "LEAF"; 108 109 /** 110 * OPERAND_MATCHES 111 */ 112 final static String OPERAND_MATCHES = "MATCHES"; 113 114 115 /** 116 * The CART. Entries can be DecisionNode or LeafNode. An 117 * ArrayList could be used here -- I chose not to because I 118 * thought it might be quicker to avoid dealing with the dynamic 119 * resizing. 120 */ 121 Node[] cart = null; 122 123 /** 124 * The number of nodes in the CART. 125 */ 126 transient int curNode = 0; 127 128 /** 129 * Creates a new CART by reading from the given URL. 130 * 131 * @param url the location of the CART data 132 * 133 * @throws IOException if errors occur while reading the data 134 */ 135 public CARTImpl(URL url) throws IOException { 136 BufferedReader reader; 137 String line; 138 139 reader = new BufferedReader(new InputStreamReader(url.openStream())); 140 line = reader.readLine(); 141 while (line != null) { 142 if (!line.startsWith("***")) { 143 parseAndAdd(line); 144 } 145 line = reader.readLine(); 146 } 147 reader.close(); 148 } 149 150 /** 151 * Creates a new CART by reading from the given reader. 152 * 153 * @param reader the source of the CART data 154 * @param nodes the number of nodes to read for this cart 155 * 156 * @throws IOException if errors occur while reading the data 157 */ 158 public CARTImpl(BufferedReader reader, int nodes) throws IOException { 159 this(nodes); 160 String line; 161 for (int i = 0; i < nodes; i++) { 162 line = reader.readLine(); 163 if (!line.startsWith("***")) { 164 parseAndAdd(line); 165 } 166 } 167 } 168 169 /** 170 * Creates a new CART that will be populated with nodes later. 171 * 172 * @param numNodes the number of nodes 173 */ 174 private CARTImpl(int numNodes) { 175 cart = new Node[numNodes]; 176 } 177 178 /** 179 * Dumps this CART to the output stream. 180 * 181 * @param os the output stream 182 * 183 * @throws IOException if an error occurs during output 184 */ 185 public void dumpBinary(DataOutputStream os) throws IOException { 186 os.writeInt(cart.length); 187 for (int i = 0; i < cart.length; i++) { 188 cart[i].dumpBinary(os); 189 } 190 } 191 192 /** 193 * Dump the CART tree as a dot file. 194 * <p> 195 * The dot tool is part of the graphviz distribution at 196 * <a href="http://www.graphviz.org/">http://www.graphviz.org/</a>. 197 * If installed, call it as "dot -O -Tpdf *.dot" from the console to 198 * generate pdfs. 199 * </p> 200 * @param out The PrintWriter to write to. 201 */ 202 public void dumpDot(PrintWriter out) { 203 out.write("digraph \"" + "CART Tree" + "\" {\n"); 204 out.write("rankdir = LR\n"); 205 206 for (Node n : cart) { 207 out.println("\tnode" + Math.abs(n.hashCode()) + " [ label=\"" 208 + n.toString() + "\", color=" + dumpDotNodeColor(n) 209 + ", shape=" + dumpDotNodeShape(n) + " ]\n"); 210 if (n instanceof DecisionNode) { 211 DecisionNode dn = (DecisionNode) n; 212 if (dn.qtrue < cart.length && cart[dn.qtrue] != null) { 213 out.write("\tnode" + Math.abs(n.hashCode()) + " -> node" 214 + Math.abs(cart[dn.qtrue].hashCode()) + " [ label=" 215 + "TRUE" + " ]\n"); 216 } 217 if (dn.qfalse < cart.length && cart[dn.qfalse] != null) { 218 out.write("\tnode" + Math.abs(n.hashCode()) + " -> node" 219 + Math.abs(cart[dn.qfalse].hashCode()) 220 + " [ label=" + "FALSE" + " ]\n"); 221 } 222 } 223 } 224 225 out.write("}\n"); 226 out.close(); 227 } 228 229 protected String dumpDotNodeColor(Node n) { 230 if (n instanceof LeafNode) { 231 return "green"; 232 } 233 return "red"; 234 } 235 236 protected String dumpDotNodeShape(Node n) { 237 return "box"; 238 } 239 240 /** 241 * Loads a CART from the input byte buffer. 242 * 243 * @param bb the byte buffer 244 * 245 * @return the CART 246 * 247 * @throws IOException if an error occurs during output 248 * 249 * Note that cart nodes are really saved as strings that 250 * have to be parsed. 251 */ 252 public static CART loadBinary(ByteBuffer bb) throws IOException { 253 int numNodes = bb.getInt(); 254 CARTImpl cart = new CARTImpl(numNodes); 255 256 for (int i = 0; i < numNodes; i++) { 257 String nodeCreationLine = Utilities.getString(bb); 258 cart.parseAndAdd(nodeCreationLine); 259 } 260 return cart; 261 } 262 263 /** 264 * Loads a CART from the input stream. 265 * 266 * @param is the input stream 267 * 268 * @return the CART 269 * 270 * @throws IOException if an error occurs during output 271 * 272 * Note that cart nodes are really saved as strings that 273 * have to be parsed. 274 */ 275 public static CART loadBinary(DataInputStream is) throws IOException { 276 int numNodes = is.readInt(); 277 CARTImpl cart = new CARTImpl(numNodes); 278 279 for (int i = 0; i < numNodes; i++) { 280 String nodeCreationLine = Utilities.getString(is); 281 cart.parseAndAdd(nodeCreationLine); 282 } 283 return cart; 284 } 285 286 /** 287 * Creates a node from the given input line and add it to the CART. 288 * It expects the TOTAL line to come before any of the nodes. 289 * 290 * @param line a line of input to parse 291 */ 292 protected void parseAndAdd(String line) { 293 StringTokenizer tokenizer = new StringTokenizer(line," "); 294 String type = tokenizer.nextToken(); 295 if (type.equals(LEAF) || type.equals(NODE)) { 296 cart[curNode] = getNode(type, tokenizer, curNode); 297 cart[curNode].setCreationLine(line); 298 curNode++; 299 } else if (type.equals(TOTAL)) { 300 cart = new Node[Integer.parseInt(tokenizer.nextToken())]; 301 curNode = 0; 302 } else { 303 throw new Error("Invalid CART type: " + type); 304 } 305 } 306 307 /** 308 * Gets the node based upon the type and tokenizer. 309 * 310 * @param type <code>NODE</code> or <code>LEAF</code> 311 * @param tokenizer the StringTokenizer containing the data to get 312 * @param currentNode the index of the current node we're looking at 313 * 314 * @return the node 315 */ 316 protected Node getNode(String type, 317 StringTokenizer tokenizer, 318 int currentNode) { 319 if (type.equals(NODE)) { 320 String feature = tokenizer.nextToken(); 321 String operand = tokenizer.nextToken(); 322 Object value = parseValue(tokenizer.nextToken()); 323 int qfalse = Integer.parseInt(tokenizer.nextToken()); 324 if (operand.equals(OPERAND_MATCHES)) { 325 return new MatchingNode(feature, 326 value.toString(), 327 currentNode + 1, 328 qfalse); 329 } else { 330 return new ComparisonNode(feature, 331 value, 332 operand, 333 currentNode + 1, 334 qfalse); 335 } 336 } else if (type.equals(LEAF)) { 337 return new LeafNode(parseValue(tokenizer.nextToken())); 338 } 339 340 return null; 341 } 342 343 /** 344 * Coerces a string into a value. 345 * 346 * @param string of the form "type(value)"; for example, "Float(2.3)" 347 * 348 * @return the value 349 */ 350 protected Object parseValue(String string) { 351 int openParen = string.indexOf("("); 352 String type = string.substring(0,openParen); 353 String value = string.substring(openParen + 1, string.length() - 1); 354 if (type.equals("String")) { 355 return value; 356 } else if (type.equals("Float")) { 357 return new Float(Float.parseFloat(value)); 358 } else if (type.equals("Integer")) { 359 return new Integer(Integer.parseInt(value)); 360 } else if (type.equals("List")) { 361 StringTokenizer tok = new StringTokenizer(value, ","); 362 int size = tok.countTokens(); 363 364 int[] values = new int[size]; 365 for (int i = 0; i < size; i++) { 366 float fval = Float.parseFloat(tok.nextToken()); 367 values[i] = Math.round(fval); 368 } 369 return values; 370 } else { 371 throw new Error("Unknown type: " + type); 372 } 373 } 374 375 /** 376 * Passes the given item through this CART and returns the 377 * interpretation. 378 * 379 * @param item the item to analyze 380 * 381 * @return the interpretation 382 */ 383 public Object interpret(Item item) { 384 int nodeIndex = 0; 385 DecisionNode decision; 386 387 while (!(cart[nodeIndex] instanceof LeafNode)) { 388 decision = (DecisionNode) cart[nodeIndex]; 389 nodeIndex = decision.getNextNode(item); 390 } 391 if (LOGGER.isLoggable(Level.FINER)) { 392 LOGGER.finer("LEAF " + cart[nodeIndex].getValue()); 393 } 394 return ((LeafNode) cart[nodeIndex]).getValue(); 395 } 396 397 /** 398 * A node for the CART. 399 */ 400 static abstract class Node { 401 /** 402 * The value of this node. 403 */ 404 protected Object value; 405 private String creationLine; 406 407 /** 408 * Create a new Node with the given value. 409 */ 410 public Node(Object value) { 411 this.value = value; 412 } 413 414 /** 415 * Get the value. 416 */ 417 public Object getValue() { 418 return value; 419 } 420 421 /** 422 * Return a string representation of the type of the value. 423 */ 424 public String getValueString() { 425 if (value == null) { 426 return "NULL()"; 427 } else if (value instanceof String) { 428 return "String(" + value.toString() + ")"; 429 } else if (value instanceof Float) { 430 return "Float(" + value.toString() + ")"; 431 } else if (value instanceof Integer) { 432 return "Integer(" + value.toString() + ")"; 433 } else { 434 return value.getClass().toString() + "(" + value.toString() + ")"; 435 } 436 } 437 438 /** 439 * sets the line of text used to create this node. 440 * @param line the creation line 441 */ 442 public void setCreationLine(String line) { 443 creationLine = line; 444 } 445 446 /** 447 * Dumps the binary form of this node. 448 * @param os the output stream to output the node on 449 * @throws IOException if an IO error occurs 450 */ 451 final public void dumpBinary(DataOutputStream os) throws IOException { 452 Utilities.outString(os, creationLine); 453 } 454 } 455 456 /** 457 * A decision node that determines the next Node to go to in the CART. 458 */ 459 abstract static class DecisionNode extends Node { 460 /** 461 * The feature used to find a value from an Item. 462 */ 463 private PathExtractor path; 464 465 /** 466 * Index of Node to go to if the comparison doesn't match. 467 */ 468 protected int qfalse; 469 470 /** 471 * Index of Node to go to if the comparison matches. 472 */ 473 protected int qtrue; 474 475 /** 476 * The feature used to find a value from an Item. 477 */ 478 public String getFeature() { 479 return path.toString(); 480 } 481 482 483 /** 484 * Find the feature associated with this DecisionNode 485 * and the given item 486 * @param item the item to start from 487 * @return the object representing the feature 488 */ 489 public Object findFeature(Item item) { 490 return path.findFeature(item); 491 } 492 493 494 /** 495 * Returns the next node based upon the 496 * descision determined at this node 497 * @param item the current item. 498 * @return the index of the next node 499 */ 500 public final int getNextNode(Item item) { 501 return getNextNode(findFeature(item)); 502 } 503 504 /** 505 * Create a new DecisionNode. 506 * @param feature the string used to get a value from an Item 507 * @param value the value to compare to 508 * @param qtrue the Node index to go to if the comparison matches 509 * @param qfalse the Node machine index to go to upon no match 510 */ 511 public DecisionNode(String feature, 512 Object value, 513 int qtrue, 514 int qfalse) { 515 super(value); 516 this.path = new PathExtractorImpl(feature, true); 517 this.qtrue = qtrue; 518 this.qfalse = qfalse; 519 } 520 521 /** 522 * Get the next Node to go to in the CART. The return 523 * value is an index in the CART. 524 */ 525 abstract public int getNextNode(Object val); 526 } 527 528 /** 529 * A decision Node that compares two values. 530 */ 531 static class ComparisonNode extends DecisionNode { 532 /** 533 * LESS_THAN 534 */ 535 final static String LESS_THAN = "<"; 536 537 /** 538 * EQUALS 539 */ 540 final static String EQUALS = "="; 541 542 /** 543 * GREATER_THAN 544 */ 545 final static String GREATER_THAN = ">"; 546 547 /** 548 * The comparison type. One of LESS_THAN, GREATER_THAN, or 549 * EQUAL_TO. 550 */ 551 String comparisonType; 552 553 /** 554 * Create a new ComparisonNode with the given values. 555 * @param feature the string used to get a value from an Item 556 * @param value the value to compare to 557 * @param comparisonType one of LESS_THAN, EQUAL_TO, or GREATER_THAN 558 * @param qtrue the Node index to go to if the comparison matches 559 * @param qfalse the Node index to go to upon no match 560 */ 561 public ComparisonNode(String feature, 562 Object value, 563 String comparisonType, 564 int qtrue, 565 int qfalse) { 566 super(feature, value, qtrue, qfalse); 567 if (!comparisonType.equals(LESS_THAN) 568 && !comparisonType.equals(EQUALS) 569 && !comparisonType.equals(GREATER_THAN)) { 570 throw new Error("Invalid comparison type: " + comparisonType); 571 } else { 572 this.comparisonType = comparisonType; 573 } 574 } 575 576 /** 577 * Compare the given value and return the appropriate Node index. 578 * IMPLEMENTATION NOTE: LESS_THAN and GREATER_THAN, the Node's 579 * value and the value passed in are converted to floating point 580 * values. For EQUAL, the Node's value and the value passed in 581 * are treated as String compares. This is the way of Flite, so 582 * be it Flite. 583 * @param val the value to compare 584 */ 585 public int getNextNode(Object val) { 586 boolean yes = false; 587 int ret; 588 589 if (comparisonType.equals(LESS_THAN) 590 || comparisonType.equals(GREATER_THAN)) { 591 float cart_fval; 592 float fval; 593 if (value instanceof Float) { 594 cart_fval = ((Float) value).floatValue(); 595 } else { 596 cart_fval = Float.parseFloat(value.toString()); 597 } 598 if (val instanceof Float) { 599 fval = ((Float) val).floatValue(); 600 } else { 601 fval = Float.parseFloat(val.toString()); 602 } 603 if (comparisonType.equals(LESS_THAN)) { 604 yes = (fval < cart_fval); 605 } else { 606 yes = (fval > cart_fval); 607 } 608 } else { // comparisonType = "=" 609 String sval = val.toString(); 610 String cart_sval = value.toString(); 611 yes = sval.equals(cart_sval); 612 } 613 if (yes) { 614 ret = qtrue; 615 } else { 616 ret = qfalse; 617 } 618 619 if (LOGGER.isLoggable(Level.FINER)) { 620 LOGGER.finer(trace(val, yes, ret)); 621 } 622 623 return ret; 624 } 625 626 private String trace(Object value, boolean match, int next) { 627 return 628 "NODE " + getFeature() + " [" 629 + value + "] " 630 + comparisonType + " [" 631 + getValue() + "] " 632 + (match ? "Yes" : "No") + " next " + 633 next; 634 } 635 636 /** 637 * Get a string representation of this Node. 638 */ 639 public String toString() { 640 return 641 "NODE " + getFeature() + " " 642 + comparisonType + " " 643 + getValueString() + " " 644 + Integer.toString(qtrue) + " " 645 + Integer.toString(qfalse); 646 } 647 } 648 649 /** 650 * A Node that checks for a regular expression match. 651 */ 652 static class MatchingNode extends DecisionNode { 653 Pattern pattern; 654 655 /** 656 * Create a new MatchingNode with the given values. 657 * @param feature the string used to get a value from an Item 658 * @param regex the regular expression 659 * @param qtrue the Node index to go to if the comparison matches 660 * @param qfalse the Node index to go to upon no match 661 */ 662 public MatchingNode(String feature, 663 String regex, 664 int qtrue, 665 int qfalse) { 666 super(feature, regex, qtrue, qfalse); 667 this.pattern = Pattern.compile(regex); 668 } 669 670 /** 671 * Compare the given value and return the appropriate CART index. 672 * @param val the value to compare -- this must be a String 673 */ 674 public int getNextNode(Object val) { 675 return pattern.matcher((String) val).matches() 676 ? qtrue 677 : qfalse; 678 } 679 680 /** 681 * Get a string representation of this Node. 682 */ 683 public String toString() { 684 StringBuffer buf = new StringBuffer( 685 NODE + " " + getFeature() + " " + OPERAND_MATCHES); 686 buf.append(getValueString() + " "); 687 buf.append(Integer.toString(qtrue) + " "); 688 buf.append(Integer.toString(qfalse)); 689 return buf.toString(); 690 } 691 } 692 693 /** 694 * The final Node of a CART. This just a marker class. 695 */ 696 static class LeafNode extends Node { 697 /** 698 * Create a new LeafNode with the given value. 699 * @param the value of this LeafNode 700 */ 701 public LeafNode(Object value) { 702 super(value); 703 } 704 705 /** 706 * Get a string representation of this Node. 707 */ 708 public String toString() { 709 return "LEAF " + getValueString(); 710 } 711 } 712} 713